Skip to content

Commit b2c58ef

Browse files
authored
[TIR] Fix Data Type Mismatch (int64 vs int32) in T.match_buffer when Working with Scalar Buffers in TIR (#18466)
This PR is trying to fix issues #17392. The issue with `T.match_buffer` for scalar elements that was causing the int64 vs. int32 type mismatch error in TVM. Fix: - Safe Type Coercion: Allows automatic casting between integer types when they have the same number of lanes - Type Safety Preserved: Still rejects incompatible type combinations (int vs float, different lane counts) --------- Co-authored-by: cchung100m <[email protected]>
1 parent 12f3bb0 commit b2c58ef

File tree

2 files changed

+40
-2
lines changed

2 files changed

+40
-2
lines changed

src/tir/transforms/lower_match_buffer.cc

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,8 +220,15 @@ class MatchBufferLower : public StmtExprMutator {
220220
}
221221

222222
void Bind(const PrimExpr& arg, PrimExpr value, const std::string& arg_name = "argument") {
223-
CHECK_EQ(arg.dtype(), value.dtype())
224-
<< "The data type mismatched: " << arg->dtype << " vs. " << value->dtype;
223+
if (arg.dtype() != value.dtype()) {
224+
if (arg.dtype().is_int() && value.dtype().is_int() &&
225+
arg.dtype().lanes() == value.dtype().lanes()) {
226+
value = cast(arg.dtype(), value);
227+
} else {
228+
CHECK_EQ(arg.dtype(), value.dtype())
229+
<< "The data type mismatched: " << arg->dtype << " vs. " << value->dtype;
230+
}
231+
}
225232
// Handle recursive case
226233
value = Substitute(std::move(value), var_map_);
227234
if (arg->IsInstance<VarNode>()) {

tests/python/tir-transform/test_tir_transform_lower_match_buffer.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,5 +532,36 @@ def test_fail_match_func_param():
532532
_check_fail(fail_match_func_param)
533533

534534

535+
@T.prim_func
536+
def scalar_match_buffer_type_coercion(a: T.handle) -> None:
537+
A = T.match_buffer(a, (8, 8))
538+
for i, j in T.grid(8, 8):
539+
with T.block(""):
540+
vi = T.axis.spatial(8, i)
541+
vj = T.axis.spatial(8, j)
542+
T.reads()
543+
T.writes(A[vi, vj])
544+
# Create scalar match buffer from single element - this triggers type coercion
545+
scalar_buf = T.match_buffer(A[vi, vj], (), offset_factor=1)
546+
scalar_buf[()] = T.float32(1.0)
547+
548+
549+
@T.prim_func
550+
def transformed_scalar_match_buffer_type_coercion(a: T.handle) -> None:
551+
A = T.match_buffer(a, (8, 8))
552+
for i, j in T.grid(8, 8):
553+
with T.block(""):
554+
vi = T.axis.spatial(8, i)
555+
vj = T.axis.spatial(8, j)
556+
T.reads()
557+
T.writes(A[vi, vj])
558+
# Scalar match_buffer eliminated, direct assignment
559+
A[vi, vj] = T.float32(1.0)
560+
561+
562+
def test_scalar_match_buffer_type_coercion():
563+
_check(scalar_match_buffer_type_coercion, transformed_scalar_match_buffer_type_coercion)
564+
565+
535566
if __name__ == "__main__":
536567
tvm.testing.main()

0 commit comments

Comments
 (0)