@@ -349,53 +349,53 @@ inline ffi::Optional<PrimExpr> TryConstFold<tir::Max>(PrimExpr a, PrimExpr b) {
349349template <>
350350inline ffi::Optional<PrimExpr> TryConstFold<tir::GT>(PrimExpr a, PrimExpr b) {
351351 TVM_ARITH_CONST_PROPAGATION ({
352- if (pa && pb) return IntImm (DataType::UInt ( 1 ), pa->value > pb->value );
353- if (fa && fb) return IntImm (DataType::UInt ( 1 ), fa->value > fb->value );
352+ if (pa && pb) return IntImm (DataType::Bool ( ), pa->value > pb->value );
353+ if (fa && fb) return IntImm (DataType::Bool ( ), fa->value > fb->value );
354354 });
355355 return std::nullopt ;
356356}
357357
358358template <>
359359inline ffi::Optional<PrimExpr> TryConstFold<tir::GE>(PrimExpr a, PrimExpr b) {
360360 TVM_ARITH_CONST_PROPAGATION ({
361- if (pa && pb) return IntImm (DataType::UInt ( 1 ), pa->value >= pb->value );
362- if (fa && fb) return IntImm (DataType::UInt ( 1 ), fa->value >= fb->value );
361+ if (pa && pb) return IntImm (DataType::Bool ( ), pa->value >= pb->value );
362+ if (fa && fb) return IntImm (DataType::Bool ( ), fa->value >= fb->value );
363363 });
364364 return std::nullopt ;
365365}
366366
367367template <>
368368inline ffi::Optional<PrimExpr> TryConstFold<tir::LT>(PrimExpr a, PrimExpr b) {
369369 TVM_ARITH_CONST_PROPAGATION ({
370- if (pa && pb) return IntImm (DataType::UInt ( 1 ), pa->value < pb->value );
371- if (fa && fb) return IntImm (DataType::UInt ( 1 ), fa->value < fb->value );
370+ if (pa && pb) return IntImm (DataType::Bool ( ), pa->value < pb->value );
371+ if (fa && fb) return IntImm (DataType::Bool ( ), fa->value < fb->value );
372372 });
373373 return std::nullopt ;
374374}
375375
376376template <>
377377inline ffi::Optional<PrimExpr> TryConstFold<tir::LE>(PrimExpr a, PrimExpr b) {
378378 TVM_ARITH_CONST_PROPAGATION ({
379- if (pa && pb) return IntImm (DataType::UInt ( 1 ), pa->value <= pb->value );
380- if (fa && fb) return IntImm (DataType::UInt ( 1 ), fa->value <= fb->value );
379+ if (pa && pb) return IntImm (DataType::Bool ( ), pa->value <= pb->value );
380+ if (fa && fb) return IntImm (DataType::Bool ( ), fa->value <= fb->value );
381381 });
382382 return std::nullopt ;
383383}
384384
385385template <>
386386inline ffi::Optional<PrimExpr> TryConstFold<tir::EQ>(PrimExpr a, PrimExpr b) {
387387 TVM_ARITH_CONST_PROPAGATION ({
388- if (pa && pb) return IntImm (DataType::UInt ( 1 ), pa->value == pb->value );
389- if (fa && fb) return IntImm (DataType::UInt ( 1 ), fa->value == fb->value );
388+ if (pa && pb) return IntImm (DataType::Bool ( ), pa->value == pb->value );
389+ if (fa && fb) return IntImm (DataType::Bool ( ), fa->value == fb->value );
390390 });
391391 return std::nullopt ;
392392}
393393
394394template <>
395395inline ffi::Optional<PrimExpr> TryConstFold<tir::NE>(PrimExpr a, PrimExpr b) {
396396 TVM_ARITH_CONST_PROPAGATION ({
397- if (pa && pb) return IntImm (DataType::UInt ( 1 ), pa->value != pb->value );
398- if (fa && fb) return IntImm (DataType::UInt ( 1 ), fa->value != fb->value );
397+ if (pa && pb) return IntImm (DataType::Bool ( ), pa->value != pb->value );
398+ if (fa && fb) return IntImm (DataType::Bool ( ), fa->value != fb->value );
399399 });
400400 return std::nullopt ;
401401}
@@ -426,7 +426,7 @@ template <>
426426inline ffi::Optional<PrimExpr> TryConstFold<tir::Not>(PrimExpr a) {
427427 const IntImmNode* pa = a.as <IntImmNode>();
428428 if (pa) {
429- return IntImm (DataType::UInt ( 1 ), !(pa->value ));
429+ return IntImm (DataType::Bool ( ), !(pa->value ));
430430 }
431431 return std::nullopt ;
432432}
0 commit comments