Skip to content

Commit 660dd12

Browse files
authored
fix tensor coercion (#94)
1 parent 430f1c6 commit 660dd12

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

mlir/extras/dialects/ext/tensor.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@
88
from ._shaped_value import ShapedValue
99
from .arith import ArithValue, Scalar, constant
1010
from ... import types as T
11-
from ...util import _unpack_sizes_element_type, _update_caller_vars, get_user_code_loc
11+
from ...util import (
12+
_unpack_sizes_element_type,
13+
_update_caller_vars,
14+
get_user_code_loc,
15+
mlir_type_to_np_dtype,
16+
)
1217
from ...._mlir_libs._mlir import register_value_caster
1318
from ....dialects import tensor
1419
from ....dialects._ods_common import _dispatch_mixed_values, get_op_result_or_op_results
@@ -185,8 +190,12 @@ def coerce(
185190
f"can't coerce {other=} because {self=} doesn't have static shape"
186191
)
187192
if isinstance(other, (int, float)):
193+
np_dtype = mlir_type_to_np_dtype(self.dtype)
188194
other = Tensor(
189-
np.full(self.shape, other), dtype=self.dtype, loc=loc, ip=ip
195+
np.full(self.shape, other, dtype=np_dtype),
196+
dtype=self.dtype,
197+
loc=loc,
198+
ip=ip,
190199
)
191200
return self, other
192201
elif isinstance(other, Scalar):

0 commit comments

Comments
 (0)