Skip to content

Commit 9e394ab

Browse files
committed
[mlir][linalg][python] Reject unsigned pooling on non-integer element types in Python
1 parent 03ef5fc commit 9e394ab

File tree

2 files changed

+60
-6
lines changed

2 files changed

+60
-6
lines changed

mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -532,9 +532,9 @@ def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value:
532532
raise NotImplementedError("Unsupported 'max' operands: {lhs}, {rhs}")
533533

534534
def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
535-
if _is_floating_point_type(lhs.type):
536-
return arith.MaximumFOp(lhs, rhs).result
537-
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
535+
if (
536+
_is_integer_type(lhs.type) and not _is_bool_type(lhs.type)
537+
) or _is_index_type(lhs.type):
538538
return arith.MaxUIOp(lhs, rhs).result
539539
raise NotImplementedError("Unsupported 'max_unsigned' operands: {lhs}, {rhs}")
540540

@@ -546,9 +546,9 @@ def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value:
546546
raise NotImplementedError("Unsupported 'min' operands: {lhs}, {rhs}")
547547

548548
def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
549-
if _is_floating_point_type(lhs.type):
550-
return arith.MinimumFOp(lhs, rhs).result
551-
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
549+
if (
550+
_is_integer_type(lhs.type) and not _is_bool_type(lhs.type)
551+
) or _is_index_type(lhs.type):
552552
return arith.MinUIOp(lhs, rhs).result
553553
raise NotImplementedError("Unsupported 'min_unsigned' operands: {lhs}, {rhs}")
554554

@@ -634,6 +634,12 @@ def _is_index_type(t: Type) -> bool:
634634
return IndexType.isinstance(t)
635635

636636

637+
def _is_bool_type(t: Type) -> bool:
638+
if not IntegerType.isinstance(t):
639+
return False
640+
return IntegerType(t).width == 1
641+
642+
637643
def _get_floating_point_width(t: Type) -> int:
638644
# TODO: Create a FloatType in the Python API and implement the switch
639645
# there.

mlir/test/python/dialects/linalg/opdsl/emit_pooling.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,51 @@ def test_f32f32_min_pooling(input, shape, init_result):
150150

151151

152152
print(module)
153+
154+
with Context() as ctx, Location.unknown():
155+
module = Module.create()
156+
with InsertionPoint(module.body):
157+
f32 = F32Type.get()
158+
bool_t = IntegerType.get_signless(1)
159+
160+
# CHECK: bool_max_unsigned_error: Unsupported 'max_unsigned' operands
161+
@func.FuncOp.from_py_func(
162+
RankedTensorType.get((1, 4, 16, 1), f32),
163+
RankedTensorType.get((2, 2), f32),
164+
RankedTensorType.get((1, 2, 4, 1), bool_t),
165+
)
166+
def test_bool_i1_max_unsigned_pooling_error(input, shape, init_result):
167+
try:
168+
pooling_poly(
169+
input,
170+
shape,
171+
outs=[init_result],
172+
reduce=BinaryFn.max_unsigned,
173+
cast=TypeFn.cast_unsigned,
174+
strides=[2, 4],
175+
dilations=[1, 2],
176+
)
177+
except NotImplementedError as e:
178+
print(f"bool_max_unsigned_error: {e}")
179+
return init_result
180+
181+
# CHECK: float_max_unsigned_error: Unsupported 'max_unsigned' operands
182+
@func.FuncOp.from_py_func(
183+
RankedTensorType.get((1, 4, 16, 1), f32),
184+
RankedTensorType.get((2, 2), f32),
185+
RankedTensorType.get((1, 2, 4, 1), f32),
186+
)
187+
def test_f32f32_max_unsigned_pooling_error(input, shape, init_result):
188+
try:
189+
pooling_poly(
190+
input,
191+
shape,
192+
outs=[init_result],
193+
reduce=BinaryFn.max_unsigned,
194+
cast=TypeFn.cast_unsigned,
195+
strides=[2, 4],
196+
dilations=[1, 2],
197+
)
198+
except NotImplementedError as e:
199+
print(f"float_max_unsigned_error: {e}")
200+
return init_result

0 commit comments

Comments
 (0)