Skip to content

Commit b2b8d72

Browse files
authored
Do not require scalar for memref set value (#101)
1 parent 8023308 commit b2b8d72

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

mlir/extras/dialects/ext/memref.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,9 @@ def __setitem__(self, idx, val):
172172
idx[i] = constant(d, index=True, loc=loc)
173173

174174
if all(isinstance(d, Scalar) for d in idx) and len(idx) == len(self.shape):
175+
if isinstance(val, (int, float)):
176+
# TODO: this is an unchecked conversion
177+
val = Scalar(val, dtype=self.dtype)
175178
assert isinstance(val, Scalar), "coordinate insert requires scalar element"
176179
store(val, self, idx, loc=loc)
177180
else:

0 commit comments

Comments
 (0)