
Commit 0eff3fa

Fix the legalization func of matmul to avoid zero index (#48)

This PR fixes a known issue in the TE legalization function of matmul, which is used here to generate NT matmul. With this PR, index 0 is no longer emitted whenever the matmul is not broadcasting on the given dimension.

1 parent 89bbaa4 commit 0eff3fa
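To make the intent concrete, here is a minimal plain-Python sketch (a hypothetical helper, not code from this repository) of the indexing rule the fix enforces: an operand keeps its spatial loop variable unless its dimension is a literal 1 that actually broadcasts against the other operand, which is the only case where index 0 is correct.

    def broadcast_index(dim, other_dim, loop_var):
        # Index 0 is only valid when this operand's dimension is a literal 1
        # being broadcast against a larger dimension on the other side.
        if dim == 1 and other_dim != 1:
            return 0
        # Not broadcasting on this dimension: keep the loop variable.
        return loop_var

    assert broadcast_index(1, 32, "v_i1") == 0        # broadcasting: index 0
    assert broadcast_index(32, 32, "v_i1") == "v_i1"  # equal dims: keep v_i1
    assert broadcast_index(1, 1, "v_i1") == "v_i1"    # both 1: no broadcast needed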

File tree: 2 files changed, +21 -23 lines

web_llm/transform/dispatch_tir_operator.py

Lines changed: 8 additions & 8 deletions
@@ -266,11 +266,11 @@ def matmul1_before(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, m
         for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(128), n):
             with T.block("matmul"):
                 v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
-                T.reads(rxplaceholder[T.int64(0), v_i1, v_i2, v_k], rxplaceholder_1[T.int64(0), v_i1, v_k, v_i3])
+                T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_k], rxplaceholder_1[v_i0, v_i1, v_k, v_i3])
                 T.writes(matmul[v_i0, v_i1, v_i2, v_i3])
                 with T.init():
                     matmul[v_i0, v_i1, v_i2, v_i3] = T.float32(0)
-                matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + rxplaceholder[T.int64(0), v_i1, v_i2, v_k] * rxplaceholder_1[T.int64(0), v_i1, v_k, v_i3]
+                matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + rxplaceholder[v_i0, v_i1, v_i2, v_k] * rxplaceholder_1[v_i0, v_i1, v_k, v_i3]


     @T.prim_func
@@ -448,11 +448,11 @@ def matmul5_before(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, v
         for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, T.int64(128), n):
             with T.block("matmul"):
                 v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
-                T.reads(rxplaceholder[T.int64(0), v_i1, v_i2, v_k], rxplaceholder_1[T.int64(0), v_i1, v_k, v_i3])
+                T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_k], rxplaceholder_1[v_i0, v_i1, v_k, v_i3])
                 T.writes(matmul_1[v_i0, v_i1, v_i2, v_i3])
                 with T.init():
                     matmul_1[v_i0, v_i1, v_i2, v_i3] = T.float32(0)
-                matmul_1[v_i0, v_i1, v_i2, v_i3] = matmul_1[v_i0, v_i1, v_i2, v_i3] + rxplaceholder[T.int64(0), v_i1, v_i2, v_k] * rxplaceholder_1[T.int64(0), v_i1, v_k, v_i3]
+                matmul_1[v_i0, v_i1, v_i2, v_i3] = matmul_1[v_i0, v_i1, v_i2, v_i3] + rxplaceholder[v_i0, v_i1, v_i2, v_k] * rxplaceholder_1[v_i0, v_i1, v_k, v_i3]


     @T.prim_func
@@ -1363,11 +1363,11 @@ def fused_NT_matmul1_divide_add_maximum_before(p_lv28: T.handle, p_lv29: T.handl
         for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, n, T.int64(128)):
             with T.block("NT_matmul"):
                 v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
-                T.reads(lv28[T.int64(0), v_i1, v_i2, v_k], lv29[T.int64(0), v_i1, v_i3, v_k])
+                T.reads(lv28[v_i0, v_i1, v_i2, v_k], lv29[v_i0, v_i1, v_i3, v_k])
                 T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3])
                 with T.init():
                     var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float32(0)
-                var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv28[T.int64(0), v_i1, v_i2, v_k] * lv29[T.int64(0), v_i1, v_i3, v_k]
+                var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv28[v_i0, v_i1, v_i2, v_k] * lv29[v_i0, v_i1, v_i3, v_k]
         for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, n):
             with T.block("T_divide"):
                 v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
@@ -1479,11 +1479,11 @@ def fused_NT_matmul6_divide1_add2_maximum1_before(lv2732: T.Buffer((T.int64(1),
         for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(128)):
             with T.block("NT_matmul"):
                 v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
-                T.reads(lv2732[T.int64(0), v_i1, v_i2, v_k], lv2733[T.int64(0), v_i1, v_i3, v_k])
+                T.reads(lv2732[v_i0, v_i1, v_i2, v_k], lv2733[v_i0, v_i1, v_i3, v_k])
                 T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3])
                 with T.init():
                     var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float32(0)
-                var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv2732[T.int64(0), v_i1, v_i2, v_k] * lv2733[T.int64(0), v_i1, v_i3, v_k]
+                var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv2732[v_i0, v_i1, v_i2, v_k] * lv2733[v_i0, v_i1, v_i3, v_k]
         for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
             with T.block("T_divide"):
                 v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])

web_llm/transform/transpose_matmul.py

Lines changed: 13 additions & 15 deletions
@@ -44,11 +44,7 @@ def te_transposed_matmul(a: te.Tensor, b: te.Tensor) -> te.Tensor:
         b_shape.append(1)

     is_a_larger = len(a_shape) > len(b_shape)
-    offset = (
-        len(a_shape) - len(b_shape)
-        if is_a_larger
-        else len(b_shape) - len(a_shape)
-    )
+    offset = len(a_shape) - len(b_shape) if is_a_larger else len(b_shape) - len(a_shape)

     a_relax = relax.Var("a", relax.TensorStructInfo(a.shape))
     bT_shape = list(b.shape)
@@ -70,15 +66,19 @@ def multiply_compute(idx_reduce):
                 a_indices.append(idx_spatial[i])
             else:
                 b_indices.append(idx_spatial[i])
-        for i in range(
-            offset, len(output_shape) - (2 - a_prepended - b_appended)
-        ):
+        for i in range(offset, len(output_shape) - (2 - a_prepended - b_appended)):
             a_dim = a_shape[i if is_a_larger else i - offset]
             b_dim = b_shape[i if not is_a_larger else i - offset]
-            a_dim_is_one = isinstance(a_dim, tir.IntImm) and a_dim == 1
-            b_dim_is_one = isinstance(b_dim, tir.IntImm) and b_dim == 1
-            a_indices.append(0 if a_dim_is_one else idx_spatial[i])
-            b_indices.append(0 if b_dim_is_one else idx_spatial[i])
+            dim_equal = a_dim == b_dim
+            if not isinstance(dim_equal, tir.IntImm) or dim_equal == 0:
+                a_dim_is_one = isinstance(a_dim, tir.IntImm) and a_dim == 1
+                b_dim_is_one = isinstance(b_dim, tir.IntImm) and b_dim == 1
+                a_indices.append(0 if a_dim_is_one else idx_spatial[i])
+                b_indices.append(0 if b_dim_is_one else idx_spatial[i])
+            else:
+                a_indices.append(idx_spatial[i])
+                b_indices.append(idx_spatial[i])
+
         if not a_prepended:
             a_indices.append(idx_spatial[-2 + b_appended])
         a_indices.append(idx_reduce)
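Restated outside TIR as a minimal sketch (plain ints standing in for tir.IntImm, and a hypothetical function name): the new branch first asks whether the two dimensions are provably equal, and only when equality cannot be proven does it fall back to the old rule where a literal size-1 dimension is indexed with 0.

    def pick_indices(a_dim, b_dim, idx):
        # Provably equal dimensions: broadcasting is impossible, so both
        # operands keep the spatial loop index (the new fast path).
        if isinstance(a_dim, int) and isinstance(b_dim, int) and a_dim == b_dim:
            return idx, idx
        # Equality not provable: fall back to the old rule, where only a
        # literal size-1 dimension is indexed with 0.
        a_idx = 0 if a_dim == 1 else idx
        b_idx = 0 if b_dim == 1 else idx
        return a_idx, b_idx

    assert pick_indices(32, 32, "v_i1") == ("v_i1", "v_i1")  # no zero index
    assert pick_indices(1, 32, "v_i1") == (0, "v_i1")        # genuine broadcast

Symbolic dimensions (e.g. the same n on both sides) take the fallback path here just as non-IntImm comparisons do in the patch, and since neither side is a literal 1, both still keep the loop index.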
@@ -118,9 +118,7 @@ def multiply_compute(idx_reduce):

 @tvm.transform.module_pass(opt_level=0, name="FuseTransposeMatmul")
 class FuseTransposeMatmul:
-    def transform_module(
-        self, mod: IRModule, ctx: tvm.transform.PassContext
-    ) -> IRModule:
+    def transform_module(self, mod: IRModule, ctx: tvm.transform.PassContext) -> IRModule:
         mod = relax.transform.FuseOpsByPattern(
             [("transpose_matmul_fuse", *TransposeMatmulCodeGenerator.pattern())]
         )(mod)
