Skip to content

Commit b6ac072

Browse files
authored
[DataType] Update to use explicit Bool Type Aligning with DLPack (#18453)
This PR updates the project to use explicit bool type which helps us to align with dlpack. It will also streamline explicit use of bool types.
1 parent f8471f8 commit b6ac072

30 files changed

+159
-122
lines changed

3rdparty/tvm-ffi

Submodule tvm-ffi updated 61 files

include/tvm/runtime/data_type.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ class DataType {
6060
kFloat = kDLFloat,
6161
kHandle = kDLOpaqueHandle,
6262
kBFloat = kDLBfloat,
63+
kBool = kDLBool,
6364
kFloat8_e3m4 = kDLFloat8_e3m4,
6465
kFloat8_e4m3 = kDLFloat8_e4m3,
6566
kFloat8_e4m3b11fnuz = kDLFloat8_e4m3b11fnuz,
@@ -137,8 +138,10 @@ class DataType {
137138
}
138139
/*! \return whether type is a scalar type. */
139140
bool is_scalar() const { return !is_scalable_vector() && lanes() == 1; }
140-
/*! \return whether type is a scalar type. */
141-
bool is_bool() const { return code() == DataType::kUInt && bits() == 1; }
141+
/*! \return whether type is a bool type. */
142+
bool is_bool() const { return code() == DataType::kBool; }
143+
/*! \return whether type can be used in a predicate expression. */
144+
bool is_predicate_dtype() const { return is_bool() || (is_uint() && bits() == 1); }
142145
/*! \return whether type is a float type. */
143146
bool is_float() const { return code() == DataType::kFloat; }
144147
/*! \return whether type is a bfloat type. */
@@ -204,7 +207,7 @@ class DataType {
204207
/*! \return whether type is a vector type. */
205208
bool is_vector() const { return lanes() > 1; }
206209
/*! \return whether type is a bool vector type. */
207-
bool is_vector_bool() const { return is_scalable_or_fixed_length_vector() && bits() == 1; }
210+
bool is_vector_bool() const { return is_scalable_or_fixed_length_vector() && is_bool(); }
208211
/*! \return whether type is a Void type. */
209212
bool is_void() const {
210213
return code() == DataType::kHandle && bits() == 0 && static_cast<int16_t>(data_.lanes) == 0;
@@ -381,7 +384,7 @@ class DataType {
381384
* \return The constructed data type.
382385
*/
383386
static DataType Bool(int lanes = 1, bool is_scalable = false) {
384-
return DataType::UInt(1, lanes, is_scalable);
387+
return DataType(kDLBool, 8, lanes, is_scalable);
385388
}
386389
/*!
387390
* \brief Construct a handle type.

include/tvm/tir/op.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -816,7 +816,7 @@ inline PrimExpr make_zero(DataType t, Span span = Span());
816816
* \return The result expression.
817817
*/
818818
inline PrimExpr const_true(int lanes = 1, Span span = Span()) {
819-
return make_const(DataType::UInt(1, lanes), 1);
819+
return make_const(DataType::Bool(lanes), 1);
820820
}
821821
/*!
822822
* \brief Make a constant false expression.
@@ -825,7 +825,7 @@ inline PrimExpr const_true(int lanes = 1, Span span = Span()) {
825825
* \return The result expression.
826826
*/
827827
inline PrimExpr const_false(int lanes = 1, Span span = Span()) {
828-
return make_const(DataType::UInt(1, lanes), 0);
828+
return make_const(DataType::Bool(lanes), 0);
829829
}
830830
/*!
831831
* \brief Get x as constant int expression.
@@ -957,7 +957,7 @@ inline bool is_no_op(const tir::Stmt& stmt) {
957957

958958
template <typename ValueType>
959959
inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) {
960-
if (t.is_int()) return IntImm(t, static_cast<int64_t>(value), span);
960+
if (t.is_int() || t.is_bool()) return IntImm(t, static_cast<int64_t>(value), span);
961961
if (t.is_uint()) {
962962
// Use IntImm if it is a small integer
963963
uint64_t uval = static_cast<uint64_t>(value);

python/tvm/script/parser/tir/operation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def _auto_broadcast(a, b, op):
6161
if (
6262
DataType(b.dtype).type_code == DataTypeCode.INT
6363
or DataType(b.dtype).type_code == DataTypeCode.UINT
64+
or DataType(b.dtype).type_code == DataTypeCode.BOOL
6465
):
6566
a = IntImm(_get_type_str(b.dtype), a)
6667
elif DataType(b.dtype).type_code == DataTypeCode.FLOAT:
@@ -80,6 +81,7 @@ def _auto_broadcast(a, b, op):
8081
if (
8182
DataType(a.dtype).type_code == DataTypeCode.INT
8283
or DataType(a.dtype).type_code == DataTypeCode.UINT
84+
or DataType(a.dtype).type_code == DataTypeCode.BOOL
8385
):
8486
b = IntImm(_get_type_str(a.dtype), b)
8587
elif DataType(a.dtype).type_code == DataTypeCode.FLOAT:

python/tvm/tir/ir_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ def allocate(self, dtype, shape, name="buf", axis_separators=None, scope=""):
448448
)
449449

450450
buffer_var = buffer.data
451-
self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x))
451+
self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="bool"), x))
452452
return BufferVar(self, buffer, dtype)
453453

454454
def pointer(self, content_type, name="ptr", scope=""):

src/arith/const_fold.h

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -349,53 +349,53 @@ inline ffi::Optional<PrimExpr> TryConstFold<tir::Max>(PrimExpr a, PrimExpr b) {
349349
template <>
350350
inline ffi::Optional<PrimExpr> TryConstFold<tir::GT>(PrimExpr a, PrimExpr b) {
351351
TVM_ARITH_CONST_PROPAGATION({
352-
if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value);
353-
if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value);
352+
if (pa && pb) return IntImm(DataType::Bool(), pa->value > pb->value);
353+
if (fa && fb) return IntImm(DataType::Bool(), fa->value > fb->value);
354354
});
355355
return std::nullopt;
356356
}
357357

358358
template <>
359359
inline ffi::Optional<PrimExpr> TryConstFold<tir::GE>(PrimExpr a, PrimExpr b) {
360360
TVM_ARITH_CONST_PROPAGATION({
361-
if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value);
362-
if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value);
361+
if (pa && pb) return IntImm(DataType::Bool(), pa->value >= pb->value);
362+
if (fa && fb) return IntImm(DataType::Bool(), fa->value >= fb->value);
363363
});
364364
return std::nullopt;
365365
}
366366

367367
template <>
368368
inline ffi::Optional<PrimExpr> TryConstFold<tir::LT>(PrimExpr a, PrimExpr b) {
369369
TVM_ARITH_CONST_PROPAGATION({
370-
if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value);
371-
if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value);
370+
if (pa && pb) return IntImm(DataType::Bool(), pa->value < pb->value);
371+
if (fa && fb) return IntImm(DataType::Bool(), fa->value < fb->value);
372372
});
373373
return std::nullopt;
374374
}
375375

376376
template <>
377377
inline ffi::Optional<PrimExpr> TryConstFold<tir::LE>(PrimExpr a, PrimExpr b) {
378378
TVM_ARITH_CONST_PROPAGATION({
379-
if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value);
380-
if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value);
379+
if (pa && pb) return IntImm(DataType::Bool(), pa->value <= pb->value);
380+
if (fa && fb) return IntImm(DataType::Bool(), fa->value <= fb->value);
381381
});
382382
return std::nullopt;
383383
}
384384

385385
template <>
386386
inline ffi::Optional<PrimExpr> TryConstFold<tir::EQ>(PrimExpr a, PrimExpr b) {
387387
TVM_ARITH_CONST_PROPAGATION({
388-
if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value);
389-
if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value);
388+
if (pa && pb) return IntImm(DataType::Bool(), pa->value == pb->value);
389+
if (fa && fb) return IntImm(DataType::Bool(), fa->value == fb->value);
390390
});
391391
return std::nullopt;
392392
}
393393

394394
template <>
395395
inline ffi::Optional<PrimExpr> TryConstFold<tir::NE>(PrimExpr a, PrimExpr b) {
396396
TVM_ARITH_CONST_PROPAGATION({
397-
if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value);
398-
if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value);
397+
if (pa && pb) return IntImm(DataType::Bool(), pa->value != pb->value);
398+
if (fa && fb) return IntImm(DataType::Bool(), fa->value != fb->value);
399399
});
400400
return std::nullopt;
401401
}
@@ -426,7 +426,7 @@ template <>
426426
inline ffi::Optional<PrimExpr> TryConstFold<tir::Not>(PrimExpr a) {
427427
const IntImmNode* pa = a.as<IntImmNode>();
428428
if (pa) {
429-
return IntImm(DataType::UInt(1), !(pa->value));
429+
return IntImm(DataType::Bool(), !(pa->value));
430430
}
431431
return std::nullopt;
432432
}

src/arith/const_int_bound.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -735,9 +735,12 @@ class ConstIntBoundAnalyzer::Impl
735735
* \return Bound that represent everything dtype can represent.
736736
*/
737737
static Entry Everything(DataType dtype) {
738-
if (!dtype.is_int() && !dtype.is_uint()) {
738+
if (!dtype.is_int() && !dtype.is_uint() && !dtype.is_bool()) {
739739
return MakeBound(kNegInf, kPosInf);
740740
}
741+
if (dtype.is_bool()) {
742+
return MakeBound(0, 1);
743+
}
741744
Entry ret;
742745
int64_t vbits = dtype.bits() - static_cast<int>(dtype.is_int());
743746
if (dtype.is_uint()) {

src/ir/expr.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,17 @@ PrimExpr PrimExpr::ConvertFallbackValue(ffi::String value) { return tir::StringI
5353
IntImm::IntImm(DataType dtype, int64_t value, Span span) {
5454
ICHECK(dtype.is_scalar()) << "ValueError: IntImm can only take scalar, but " << dtype
5555
<< " was supplied.";
56-
ICHECK(dtype.is_int() || dtype.is_uint())
57-
<< "ValueError: IntImm supports only int or uint type, but " << dtype << " was supplied.";
56+
ICHECK(dtype.is_int() || dtype.is_uint() || dtype.is_bool())
57+
<< "ValueError: IntImm supports only int or uint or bool type, but " << dtype
58+
<< " was supplied.";
5859
if (dtype.is_uint()) {
5960
ICHECK_GE(value, 0U) << "ValueError: Literal value " << value
6061
<< " is negative for unsigned integer type " << dtype;
6162
if (dtype.bits() < 64) {
6263
ICHECK_LT(value, 1LL << dtype.bits())
6364
<< "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
6465
}
65-
} else if (dtype.bits() == 1) {
66+
} else if (dtype.bits() == 1 || dtype.is_bool()) {
6667
// int(1)
6768
ICHECK(value == 0 || value == 1) << "ValueError: " << value << " exceeds range of " << dtype;
6869
} else if (dtype.bits() < 64) {

src/relax/transform/utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ inline Constant MakeConstantScalar(T value, DataType dtype) {
328328
*static_cast<int32_t*>(arr->data) = static_cast<int32_t>(value);
329329
} else if (dtype == DataType::Int(64)) {
330330
*static_cast<int64_t*>(arr->data) = static_cast<int64_t>(value);
331-
} else if (dtype == DataType::UInt(1)) {
331+
} else if (dtype == DataType::Bool()) {
332332
*static_cast<bool*>(arr->data) = static_cast<bool>(value);
333333
} else if (dtype == DataType::UInt(8)) {
334334
*static_cast<uint8_t*>(arr->data) = static_cast<uint8_t>(value);

src/runtime/vm/builtin.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -535,7 +535,7 @@ bool ReadIfCond(ffi::AnyView cond) {
535535
if (arr->device.device_type != kDLCPU) {
536536
arr = arr.CopyTo(DLDevice{kDLCPU, 0});
537537
}
538-
ICHECK(arr->dtype.code == kDLInt || arr->dtype.code == kDLUInt);
538+
ICHECK(arr->dtype.code == kDLInt || arr->dtype.code == kDLUInt || arr->dtype.code == kDLBool);
539539
int64_t result;
540540
switch (arr->dtype.bits) {
541541
case 1: {

0 commit comments

Comments
 (0)