@@ -266,11 +266,11 @@ def matmul1_before(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, m
266
266
for i0 , i1 , i2 , i3 , k in T .grid (T .int64 (1 ), T .int64 (32 ), T .int64 (1 ), T .int64 (128 ), n ):
267
267
with T .block ("matmul" ):
268
268
v_i0 , v_i1 , v_i2 , v_i3 , v_k = T .axis .remap ("SSSSR" , [i0 , i1 , i2 , i3 , k ])
269
- T .reads (rxplaceholder [T . int64 ( 0 ) , v_i1 , v_i2 , v_k ], rxplaceholder_1 [T . int64 ( 0 ) , v_i1 , v_k , v_i3 ])
269
+ T .reads (rxplaceholder [v_i0 , v_i1 , v_i2 , v_k ], rxplaceholder_1 [v_i0 , v_i1 , v_k , v_i3 ])
270
270
T .writes (matmul [v_i0 , v_i1 , v_i2 , v_i3 ])
271
271
with T .init ():
272
272
matmul [v_i0 , v_i1 , v_i2 , v_i3 ] = T .float32 (0 )
273
- matmul [v_i0 , v_i1 , v_i2 , v_i3 ] = matmul [v_i0 , v_i1 , v_i2 , v_i3 ] + rxplaceholder [T . int64 ( 0 ) , v_i1 , v_i2 , v_k ] * rxplaceholder_1 [T . int64 ( 0 ) , v_i1 , v_k , v_i3 ]
273
+ matmul [v_i0 , v_i1 , v_i2 , v_i3 ] = matmul [v_i0 , v_i1 , v_i2 , v_i3 ] + rxplaceholder [v_i0 , v_i1 , v_i2 , v_k ] * rxplaceholder_1 [v_i0 , v_i1 , v_k , v_i3 ]
274
274
275
275
276
276
@T .prim_func
@@ -448,11 +448,11 @@ def matmul5_before(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, v
448
448
for i0 , i1 , i2 , i3 , k in T .grid (T .int64 (1 ), T .int64 (32 ), n , T .int64 (128 ), n ):
449
449
with T .block ("matmul" ):
450
450
v_i0 , v_i1 , v_i2 , v_i3 , v_k = T .axis .remap ("SSSSR" , [i0 , i1 , i2 , i3 , k ])
451
- T .reads (rxplaceholder [T . int64 ( 0 ) , v_i1 , v_i2 , v_k ], rxplaceholder_1 [T . int64 ( 0 ) , v_i1 , v_k , v_i3 ])
451
+ T .reads (rxplaceholder [v_i0 , v_i1 , v_i2 , v_k ], rxplaceholder_1 [v_i0 , v_i1 , v_k , v_i3 ])
452
452
T .writes (matmul_1 [v_i0 , v_i1 , v_i2 , v_i3 ])
453
453
with T .init ():
454
454
matmul_1 [v_i0 , v_i1 , v_i2 , v_i3 ] = T .float32 (0 )
455
- matmul_1 [v_i0 , v_i1 , v_i2 , v_i3 ] = matmul_1 [v_i0 , v_i1 , v_i2 , v_i3 ] + rxplaceholder [T . int64 ( 0 ) , v_i1 , v_i2 , v_k ] * rxplaceholder_1 [T . int64 ( 0 ) , v_i1 , v_k , v_i3 ]
455
+ matmul_1 [v_i0 , v_i1 , v_i2 , v_i3 ] = matmul_1 [v_i0 , v_i1 , v_i2 , v_i3 ] + rxplaceholder [v_i0 , v_i1 , v_i2 , v_k ] * rxplaceholder_1 [v_i0 , v_i1 , v_k , v_i3 ]
456
456
457
457
458
458
@T .prim_func
@@ -1363,11 +1363,11 @@ def fused_NT_matmul1_divide_add_maximum_before(p_lv28: T.handle, p_lv29: T.handl
1363
1363
for i0 , i1 , i2 , i3 , k in T .grid (T .int64 (1 ), T .int64 (32 ), n , n , T .int64 (128 )):
1364
1364
with T .block ("NT_matmul" ):
1365
1365
v_i0 , v_i1 , v_i2 , v_i3 , v_k = T .axis .remap ("SSSSR" , [i0 , i1 , i2 , i3 , k ])
1366
- T .reads (lv28 [T . int64 ( 0 ) , v_i1 , v_i2 , v_k ], lv29 [T . int64 ( 0 ) , v_i1 , v_i3 , v_k ])
1366
+ T .reads (lv28 [v_i0 , v_i1 , v_i2 , v_k ], lv29 [v_i0 , v_i1 , v_i3 , v_k ])
1367
1367
T .writes (var_NT_matmul_intermediate [v_i0 , v_i1 , v_i2 , v_i3 ])
1368
1368
with T .init ():
1369
1369
var_NT_matmul_intermediate [v_i0 , v_i1 , v_i2 , v_i3 ] = T .float32 (0 )
1370
- var_NT_matmul_intermediate [v_i0 , v_i1 , v_i2 , v_i3 ] = var_NT_matmul_intermediate [v_i0 , v_i1 , v_i2 , v_i3 ] + lv28 [T . int64 ( 0 ) , v_i1 , v_i2 , v_k ] * lv29 [T . int64 ( 0 ) , v_i1 , v_i3 , v_k ]
1370
+ var_NT_matmul_intermediate [v_i0 , v_i1 , v_i2 , v_i3 ] = var_NT_matmul_intermediate [v_i0 , v_i1 , v_i2 , v_i3 ] + lv28 [v_i0 , v_i1 , v_i2 , v_k ] * lv29 [v_i0 , v_i1 , v_i3 , v_k ]
1371
1371
for ax0 , ax1 , ax2 , ax3 in T .grid (T .int64 (1 ), T .int64 (32 ), n , n ):
1372
1372
with T .block ("T_divide" ):
1373
1373
v_ax0 , v_ax1 , v_ax2 , v_ax3 = T .axis .remap ("SSSS" , [ax0 , ax1 , ax2 , ax3 ])
@@ -1479,11 +1479,11 @@ def fused_NT_matmul6_divide1_add2_maximum1_before(lv2732: T.Buffer((T.int64(1),
1479
1479
for i0 , i1 , i2 , i3 , k in T .grid (T .int64 (1 ), T .int64 (32 ), T .int64 (1 ), n , T .int64 (128 )):
1480
1480
with T .block ("NT_matmul" ):
1481
1481
v_i0 , v_i1 , v_i2 , v_i3 , v_k = T .axis .remap ("SSSSR" , [i0 , i1 , i2 , i3 , k ])
1482
- T .reads (lv2732 [T . int64 ( 0 ) , v_i1 , v_i2 , v_k ], lv2733 [T . int64 ( 0 ) , v_i1 , v_i3 , v_k ])
1482
+ T .reads (lv2732 [v_i0 , v_i1 , v_i2 , v_k ], lv2733 [v_i0 , v_i1 , v_i3 , v_k ])
1483
1483
T .writes (var_NT_matmul_intermediate [v_i0 , v_i1 , v_i2 , v_i3 ])
1484
1484
with T .init ():
1485
1485
var_NT_matmul_intermediate [v_i0 , v_i1 , v_i2 , v_i3 ] = T .float32 (0 )
1486
- var_NT_matmul_intermediate [v_i0 , v_i1 , v_i2 , v_i3 ] = var_NT_matmul_intermediate [v_i0 , v_i1 , v_i2 , v_i3 ] + lv2732 [T . int64 ( 0 ) , v_i1 , v_i2 , v_k ] * lv2733 [T . int64 ( 0 ) , v_i1 , v_i3 , v_k ]
1486
+ var_NT_matmul_intermediate [v_i0 , v_i1 , v_i2 , v_i3 ] = var_NT_matmul_intermediate [v_i0 , v_i1 , v_i2 , v_i3 ] + lv2732 [v_i0 , v_i1 , v_i2 , v_k ] * lv2733 [v_i0 , v_i1 , v_i3 , v_k ]
1487
1487
for ax0 , ax1 , ax2 , ax3 in T .grid (T .int64 (1 ), T .int64 (32 ), T .int64 (1 ), n ):
1488
1488
with T .block ("T_divide" ):
1489
1489
v_ax0 , v_ax1 , v_ax2 , v_ax3 = T .axis .remap ("SSSS" , [ax0 , ax1 , ax2 , ax3 ])
0 commit comments