Skip to content

fix(tir): convert bool dtype to uint1 in BlockRealize predicates#341

Open
cipher982 wants to merge 5 commits intomlc-ai:mlcfrom
cipher982:fix/tir-bool-predicate
Open

fix(tir): convert bool dtype to uint1 in BlockRealize predicates#341
cipher982 wants to merge 5 commits intomlc-ai:mlcfrom
cipher982:fix/tir-bool-predicate

Conversation

@cipher982
Copy link
Copy Markdown

Problem

TVM's C++ runtime rejects "bool" as a dtype for IntImm expressions:

CHECK(dtype.is_int() || dtype.is_uint())
    << "cannot make const for type " << dtype;

However, tir.BlockRealize.__init__ accepts predicate: Union[PrimExpr, bool] per the documented API, and various code paths create predicates with dtype "bool":

  • Python bool literals: True/False
  • Legacy API: tir.const(True, "bool")
  • Direct creation: tir.IntImm("bool", 1)
  • Dynamic expressions: tir.EQ(i, j), tir.Not(flag), etc.

This causes runtime errors during compilation, particularly for WebGPU/WASM targets in MLC-LLM.

Solution

Modified BlockRealize.__init__ in python/tvm/tir/stmt.py to convert any "bool" dtype predicate to TVM's canonical boolean representation (uint1):

  1. Constant predicates: Extract value and create IntImm("uint1", value)
  2. Dynamic predicates: Wrap in Cast("uint1", predicate) to preserve logic

Testing

Added comprehensive regression tests in tests/python/tir-base/test_tir_block_realize_predicate.py:

  • Parameterized tests for all constant predicate forms
  • Parameterized tests for dynamic predicates (EQ, Not, Or, And)
  • Verification that dynamic logic is preserved via Cast
  • Real-world PrimFunc context test

tqchen and others added 5 commits November 12, 2025 13:37
MLC local ci setup. Also CI for Windows and macOS building,
which may take 90-100 mins.

Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
- Revert "[CMake][MSVC] Disable permissive mode for MSVC builds (#16343)"
- Skip MSC tests
- Disable NNPack and TFLite
- Tweak CMAKE_CUDA_ARCHITECTURES
TVM C++ runtime rejects 'bool' dtype with:
  CHECK(dtype.is_int() || dtype.is_uint())

Extended fix to handle ALL forms of bool predicates:
- Python bool literals (True/False) → IntImm("uint1", value)
- Constant expressions (const, IntImm) → IntImm("uint1", value)
- Dynamic expressions (EQ, Not, Or, And) → Cast("uint1", predicate)

Critical: Dynamic predicates MUST be cast, not replaced with constant
True. This preserves runtime logic for guarded blocks that depend on
loop indices or runtime conditions.

Implementation note:
str(predicate.dtype) is used because TVM's dtype attribute returns a
DataType object rather than a string.

Fixes compilation failures for WebGPU/WASM targets in MLC-LLM when
BlockRealize is constructed with boolean predicates.

Testing:
- Comprehensive regression tests with pytest parameterization
- Tests cover Python bool, tir.const, tir.IntImm (constants)
- Tests cover EQ, Not, Or, And (dynamic predicates)
- Verifies dynamic predicates are cast, not replaced
@cipher982 cipher982 force-pushed the fix/tir-bool-predicate branch from 16a61fc to 033e0cb Compare December 5, 2025 17:14
@cipher982
Copy link
Copy Markdown
Author

actually I see #342 addresses this at the c++ level in IntImm/MakeConstScalar. That's probably cleaner! mine was a workaround at the python boundary. We can close this if #342 is the preferred approach.

Copy link
Copy Markdown

@Ihorog Ihorog left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ок

Copy link
Copy Markdown

@Ihorog Ihorog left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ок

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants