@@ -1973,28 +1973,28 @@ def decode_sch_func(orig_func):
 
 @T.prim_func
 def fused_decode3_matmul1_before(lv2931: T.Buffer((T.int64(512), T.int64(32001)), "uint32"), lv2932: T.Buffer((T.int64(128), T.int64(32001)), "uint32"), lv1511: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32001)), "float32")):
-    T.func_attr({"tir.noalias": T.bool(True)})
-    # with T.block("root"):
-    var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(32001)))
-    for i, j in T.grid(T.int64(4096), T.int64(32001)):
-        with T.block("decode"):
-            v_i, v_j = T.axis.remap("SS", [i, j])
-            T.reads(lv2931[v_i // T.int64(8), v_j], lv2932[v_i // T.int64(32), v_j])
-            T.writes(var_decode_intermediate[v_i, v_j])
-            var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(lv2931[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv2932[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv2932[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16)))
-    for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(32001), T.int64(4096)):
-        with T.block("matmul"):
-            v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
-            T.reads(lv1511[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
-            T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
-            with T.init():
-                var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0)
-            var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1511[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
+    T.func_attr({"tir.noalias": T.bool(True)})
+    # with T.block("root"):
+    var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(32001)))
+    for i, j in T.grid(T.int64(4096), T.int64(32001)):
+        with T.block("decode"):
+            v_i, v_j = T.axis.remap("SS", [i, j])
+            T.reads(lv2931[v_i // T.int64(8), v_j], lv2932[v_i // T.int64(32), v_j])
+            T.writes(var_decode_intermediate[v_i, v_j])
+            var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(lv2931[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv2932[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv2932[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16)))
+    for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(32001), T.int64(4096)):
+        with T.block("matmul"):
+            v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+            T.reads(lv1511[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
+            T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
+            with T.init():
+                var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0)
+            var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1511[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
 
 
 @T.prim_func
 def fused_decode3_matmul1_after(lv1123: T.Buffer((T.int64(512), T.int64(32001)), "uint32"), lv1124: T.Buffer((T.int64(128), T.int64(32001)), "uint32"), lv1511: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32001)), "float32")):
-    T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+    T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True), "tir.is_scheduled": 1})
     # with T.block("root"):
     var_decode_intermediate_pad_local = T.alloc_buffer((T.int64(4096), T.int64(35072)), scope="local")
     var_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(35072)), scope="local")
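# NOTE (illustrative, not part of the diff): the "decode" block above unpacks
# 4-bit group-quantized weights. Reading the layout off the index arithmetic:
# eight 4-bit codes are packed per uint32 along the reduction axis (lv2931),
# and each group of 32 rows shares one uint32 (lv2932) whose low and high 16
# bits hold a scale and a bias, each a float32 truncated to its upper half
# (bfloat16-style). A minimal NumPy sketch of one decoded element, with
# function and variable names of my own choosing:
import numpy as np

def dequantize_element(packed_w, packed_sb, i, j):
    """NumPy mirror of var_decode_intermediate[i, j] in the TIR above."""
    # 4-bit code: select the (i % 8)-th nibble of the packing word.
    code = (packed_w[i // 8, j] >> np.uint32(i % 8 * 4)) & np.uint32(0xF)
    word = packed_sb[i // 32, j]
    # Re-expand the truncated floats: low half -> scale, high half -> bias.
    scale = np.uint32((word & np.uint32(0xFFFF)) << np.uint32(16)).view(np.float32)
    bias = np.uint32((word >> np.uint32(16)) << np.uint32(16)).view(np.float32)
    return np.float32(code) * scale + bias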
@@ -2415,59 +2415,59 @@ def fused_decode6_fused_matmul9_add3_before(lv1623: T.Buffer((T.int64(1376), T.i
 
 @T.prim_func
 def fused_decode6_fused_matmul9_add3_after(lv1158: T.Buffer((T.int64(1376), T.int64(4096)), "uint32"), lv1159: T.Buffer((T.int64(344), T.int64(4096)), "uint32"), lv6: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float32"), lv4: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32")):
-    T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
-    # with T.block("root"):
-    var_decode_intermediate_local = T.alloc_buffer((T.int64(11008), T.int64(4096)), scope="local")
-    var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="local")
-    lv6_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), scope="shared")
-    for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}):
-        for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
-            for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
-                with T.block("matmul_init"):
-                    v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
-                    v_i1 = T.axis.spatial(T.int64(1), T.int64(0))
-                    v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)
-                    T.reads()
-                    T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])
-                    var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0)
-                for k_0_0 in range(T.int64(2)):
-                    for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(22)):
-                        for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
-                            with T.block("lv6_shared"):
-                                v0 = T.axis.spatial(T.int64(1), ax0)
-                                v1 = T.axis.spatial(T.int64(1), T.int64(0))
-                                v2 = T.axis.spatial(T.int64(11008), k_0_0 * T.int64(5504) + (ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1))
-                                T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(5504))
-                                T.reads(lv6[v0, v1, v2])
-                                T.writes(lv6_shared[v0, v1, v2])
-                                T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]})
-                                lv6_shared[v0, v1, v2] = lv6[v0, v1, v2]
-                    for k_0_1 in range(T.int64(86)):
-                        for ax0_0 in range(T.int64(8)):
-                            for ax0_1 in T.unroll(T.int64(8)):
-                                for ax1 in range(T.int64(1)):
-                                    with T.block("decode"):
-                                        v_j = T.axis.spatial(T.int64(11008), k_0_0 * T.int64(5504) + k_0_1 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1)
-                                        v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1)
-                                        T.reads(lv1158[v_j // T.int64(8), v_i], lv1159[v_j // T.int64(32), v_i])
-                                        T.writes(var_decode_intermediate_local[v_j, v_i])
-                                        var_decode_intermediate_local[v_j, v_i] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1158[v_j // T.int64(8), v_i], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1159[v_j // T.int64(32), v_i], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1159[v_j // T.int64(32), v_i], T.uint32(16)), T.uint32(65535)), T.uint32(16)))
-                        for k_0_2_k_1_fused in range(T.int64(64)):
-                            with T.block("matmul_update"):
-                                v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
-                                v_i1 = T.axis.spatial(T.int64(1), T.int64(0))
-                                v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)
-                                v_k = T.axis.reduce(T.int64(11008), k_0_0 * T.int64(5504) + k_0_1 * T.int64(64) + k_0_2_k_1_fused)
-                                T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv6_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2])
-                                T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])
-                                var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv6_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2]
-                for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)):
-                    with T.block("var_matmul_intermediate_local"):
-                        v0, v1 = T.axis.remap("SS", [ax0, ax1])
-                        v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2)
-                        T.reads(lv4[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2])
-                        T.writes(p_output0_intermediate[v0, v1, v2])
-                        p_output0_intermediate[v0, v1, v2] = lv4[v0, v1, v2] + var_matmul_intermediate_local[v0, v1, v2]
+    T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True), "tir.is_scheduled": 1})
+    # with T.block("root"):
+    var_decode_intermediate_local = T.alloc_buffer((T.int64(11008), T.int64(4096)), scope="local")
+    var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="local")
+    lv6_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), scope="shared")
+    for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}):
+        for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
+            for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+                with T.block("matmul_init"):
+                    v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
+                    v_i1 = T.axis.spatial(T.int64(1), T.int64(0))
+                    v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)
+                    T.reads()
+                    T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])
+                    var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0)
+                for k_0_0 in range(T.int64(2)):
+                    for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(22)):
+                        for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+                            with T.block("lv6_shared"):
+                                v0 = T.axis.spatial(T.int64(1), ax0)
+                                v1 = T.axis.spatial(T.int64(1), T.int64(0))
+                                v2 = T.axis.spatial(T.int64(11008), k_0_0 * T.int64(5504) + (ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1))
+                                T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(5504))
+                                T.reads(lv6[v0, v1, v2])
+                                T.writes(lv6_shared[v0, v1, v2])
+                                T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]})
+                                lv6_shared[v0, v1, v2] = lv6[v0, v1, v2]
+                    for k_0_1 in range(T.int64(86)):
+                        for ax0_0 in range(T.int64(8)):
+                            for ax0_1 in T.unroll(T.int64(8)):
+                                for ax1 in range(T.int64(1)):
+                                    with T.block("decode"):
+                                        v_j = T.axis.spatial(T.int64(11008), k_0_0 * T.int64(5504) + k_0_1 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1)
+                                        v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1)
+                                        T.reads(lv1158[v_j // T.int64(8), v_i], lv1159[v_j // T.int64(32), v_i])
+                                        T.writes(var_decode_intermediate_local[v_j, v_i])
+                                        var_decode_intermediate_local[v_j, v_i] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1158[v_j // T.int64(8), v_i], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1159[v_j // T.int64(32), v_i], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1159[v_j // T.int64(32), v_i], T.uint32(16)), T.uint32(65535)), T.uint32(16)))
+                        for k_0_2_k_1_fused in range(T.int64(64)):
+                            with T.block("matmul_update"):
+                                v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
+                                v_i1 = T.axis.spatial(T.int64(1), T.int64(0))
+                                v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)
+                                v_k = T.axis.reduce(T.int64(11008), k_0_0 * T.int64(5504) + k_0_1 * T.int64(64) + k_0_2_k_1_fused)
+                                T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv6_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2])
+                                T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])
+                                var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv6_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2]
+                for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)):
+                    with T.block("var_matmul_intermediate_local"):
+                        v0, v1 = T.axis.remap("SS", [ax0, ax1])
+                        v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2)
+                        T.reads(lv4[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2])
+                        T.writes(p_output0_intermediate[v0, v1, v2])
+                        p_output0_intermediate[v0, v1, v2] = lv4[v0, v1, v2] + var_matmul_intermediate_local[v0, v1, v2]
 
 # fmt: on
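# NOTE (illustrative, not part of the diff): the substantive change in both
# hunks is the added "tir.is_scheduled": 1 entry in T.func_attr on the
# hand-scheduled *_after functions, which tags a PrimFunc as already scheduled
# so a downstream scheduling pass can detect the flag and leave the function
# untouched. Outside TVMScript the same attribute can be attached with TVM's
# with_attr; mark_scheduled below is a hypothetical helper, not an API of this
# file:
import tvm

def mark_scheduled(mod: tvm.IRModule, func_name: str) -> tvm.IRModule:
    """Attach tir.is_scheduled to one function of an IRModule."""
    gv = mod.get_global_var(func_name)
    mod.update_func(gv, mod[gv].with_attr("tir.is_scheduled", 1))
    return mod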
################################################