Skip to content

Commit c3a52ea

Browse files
authored
[TIR] Update function signatures for decompose_reduction (#18505)
## Related Issue closes #18215
1 parent 6e0d4d5 commit c3a52ea

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

python/tvm/tir/schedule/schedule.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2411,7 +2411,7 @@ def decompose_reduction(self, block: Union[BlockRV, str], loop: LoopRV) -> Block
24112411
.. code-block:: python
24122412
24132413
@T.prim_func
2414-
def before_decompose(a: ty.handle, c: ty.handle) -> None:
2414+
def before_decompose(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
24152415
A = tir.match_buffer(a, [128, 128])
24162416
B = tir.match_buffer(b, [128, 128])
24172417
C = tir.match_buffer(c, [128, 128])
@@ -2436,7 +2436,7 @@ def before_decompose(a: ty.handle, c: ty.handle) -> None:
24362436
.. code-block:: python
24372437
24382438
@T.prim_func
2439-
def after_decompose(a: ty.handle, c: ty.handle) -> None:
2439+
def after_decompose(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
24402440
A = tir.match_buffer(a, [128, 128])
24412441
B = tir.match_buffer(b, [128, 128])
24422442
C = tir.match_buffer(c, [128, 128])

0 commit comments

Comments
 (0)