# Global switch for debug_print(); when False, debug output is suppressed
# unless a call passes forcePrint=True.
DEBUG = False
2929
3030
class TmaDescriptorBuilder:
    """Builds NVGPU TMA (Tensor Memory Accelerator) descriptors.

    Bundles the static configuration of a TMA transfer — swizzle mode, L2
    promotion, out-of-bounds handling, interleave kind, the TMA box shape,
    and the source memref type — so the matching
    `!nvgpu.tensormap.descriptor` type and `nvgpu.tma.create.descriptor`
    op can be produced on demand.
    """

    def __init__(self, swizzle, l2promo, oob, interleave, tma_box_shape, memref_ty):
        self.swizzle = swizzle  # mlir.nvgpu.TensorMapSwizzleKind
        self.l2promo = l2promo  # mlir.nvgpu.TensorMapL2PromoKind
        self.oob = oob  # mlir.nvgpu.TensorMapOOBKind
        self.interleave = interleave  # mlir.nvgpu.TensorMapInterleaveKind
        self.tma_box_shape = tma_box_shape  # per-dimension TMA box sizes
        self.memref_ty = memref_ty  # MemRefType of the source buffer

    @property
    def tensormap_descriptor_ty(self):
        """Returns the `!nvgpu.tensormap.descriptor` type for this configuration."""
        # The descriptor's tensor operand is a memref of the TMA box shape in
        # address space 3 (shared memory, per the NVGPU/GPU dialect convention).
        tensor_memref_ty = ir.MemRefType.get(
            self.tma_box_shape,
            self.memref_ty.element_type,
            memory_space=ir.Attribute.parse("3"),
        )
        return nvgpu.TensorMapDescriptorType.get(
            tensor_memref_ty,
            self.swizzle,
            self.l2promo,
            self.oob,
            self.interleave,
        )

    def tma_descriptor_op(self, device_ptr):
        """Returns the result of a `nvgpu.tma.create.descriptor` op.

        Args:
          device_ptr: ranked device memref value to build the descriptor for.
        """
        desc_ty = self.tensormap_descriptor_ty
        # TmaCreateDescriptorOp expects an unranked memref, so cast the
        # ranked device pointer first (element type and memory space kept).
        device_unranked_memref = memref.CastOp(
            ir.UnrankedMemRefType.get(
                self.memref_ty.element_type, self.memref_ty.memory_space
            ),
            device_ptr,
        )
        # Materialize the box-shape index constants as a concrete list rather
        # than a lazy, single-use `map` object; also avoid shadowing this
        # method's own name with the local holding the created op.
        desc_op = nvgpu.TmaCreateDescriptorOp(
            desc_ty, device_unranked_memref, [c(dim) for dim in self.tma_box_shape]
        )
        return desc_op.result
71+
72+
3173def debug_print (fmt , * args , predicate = None , threadNumber = - 1 , forcePrint = False ):
3274 if not DEBUG and not forcePrint :
3375 return
@@ -162,28 +204,6 @@ def generate_matmul_ws(
162204 + str (num_stages )
163205 + ">"
164206 )
165- a_tma_desc_ty = ir .Type .parse (
166- "!nvgpu.tensormap.descriptor<tensor = memref<"
167- + str (BLOCK_M )
168- + "x"
169- + str (TMA_LAST_DIM_F16 )
170- + "x"
171- + str (a_elem_ty )
172- + ", "
173- + str (smem_space )
174- + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
175- )
176- b_tma_desc_ty = ir .Type .parse (
177- "!nvgpu.tensormap.descriptor<tensor = memref<"
178- + str (BLOCK_K )
179- + "x"
180- + str (TMA_LAST_DIM_F16 )
181- + "x"
182- + str (b_elem_ty )
183- + ", "
184- + str (smem_space )
185- + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
186- )
187207 acc_ty = ir .Type .parse (
188208 "!nvgpu.warpgroup.accumulator<fragmented=vector<"
189209 + str (BLOCK_M )
@@ -240,21 +260,26 @@ def generate_matmul_ws(
240260 t7 = gpu .wait (token_ty , [t6 ])
241261
242262 # Step 2. Create TMA Descriptors
243- tma_specs = [
244- (a_device , a_tma_desc_ty , a_tma_shape ),
245- (b_device , b_tma_desc_ty , b_tma_shape ),
246- ]
247- tma_descs = []
248- for x_device , tensor_map_ty , tile_shape in tma_specs :
249- x_unranked = memref .cast (
250- ir .UnrankedMemRefType .get (a_elem_ty , a_ty .memory_space ), x_device
251- )
252- tma_descs .append (
253- nvgpu .TmaCreateDescriptorOp (
254- tensor_map_ty , x_unranked , map (c , tile_shape )
255- ).result
256- )
257- a_tma_desc , b_tma_desc = tma_descs
263+ a_tma_desc = TmaDescriptorBuilder (
264+ nvgpu .TensorMapSwizzleKind .SWIZZLE_128B ,
265+ nvgpu .TensorMapL2PromoKind .L2PROMO_NONE ,
266+ nvgpu .TensorMapOOBKind .OOB_ZERO ,
267+ nvgpu .TensorMapInterleaveKind .INTERLEAVE_NONE ,
268+ a_tma_shape ,
269+ a_ty ,
270+ )
271+
272+ b_tma_desc = TmaDescriptorBuilder (
273+ nvgpu .TensorMapSwizzleKind .SWIZZLE_128B ,
274+ nvgpu .TensorMapL2PromoKind .L2PROMO_NONE ,
275+ nvgpu .TensorMapOOBKind .OOB_ZERO ,
276+ nvgpu .TensorMapInterleaveKind .INTERLEAVE_NONE ,
277+ b_tma_shape ,
278+ b_ty ,
279+ )
280+
281+ a_tma_desc_op = a_tma_desc .tma_descriptor_op (a_device )
282+ b_tma_desc_op = b_tma_desc .tma_descriptor_op (b_device )
258283
259284 # Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
260285 cta_m = M // BLOCK_M
@@ -267,7 +292,7 @@ def generate_matmul_ws(
267292 [t7 ],
268293 * map (c , grid ),
269294 * map (c , block ),
270- dynamicSharedMemorySize = c (smem_size , ty = T .i32 ())
295+ dynamicSharedMemorySize = c (smem_size , ty = T .i32 ()),
271296 )
272297 launch_op .body .blocks .append (* ([T .index ()] * 12 ))
273298 with ir .InsertionPoint (launch_op .body .blocks [0 ]):
@@ -315,8 +340,8 @@ def generate_matmul_ws(
315340 gpu .barrier ()
316341
317342 # GPU Step 3. Prefetch TMA descriptors
318- nvgpu .tma_prefetch_descriptor (a_tma_desc , predicate = wgPrimaryThread )
319- nvgpu .tma_prefetch_descriptor (b_tma_desc , predicate = wgPrimaryThread )
343+ nvgpu .tma_prefetch_descriptor (a_tma_desc_op , predicate = wgPrimaryThread )
344+ nvgpu .tma_prefetch_descriptor (b_tma_desc_op , predicate = wgPrimaryThread )
320345
321346 ns = num_stages if num_stages == 1 else num_stages - 1
322347 # GPU Step 5. Producer Warpgroup (TMA Warpgroup)
@@ -405,15 +430,15 @@ def generate_matmul_ws(
405430 nvgpu .TmaAsyncLoadOp (
406431 a_tma_slice ,
407432 mbarTMA ,
408- a_tma_desc ,
433+ a_tma_desc_op ,
409434 coordinates = [coord , dimX ],
410435 mbarId = stage ,
411436 predicate = producerPrimaryThread ,
412437 )
413438 nvgpu .TmaAsyncLoadOp (
414439 b_tma_slice_1 ,
415440 mbarTMA ,
416- b_tma_desc ,
441+ b_tma_desc_op ,
417442 coordinates = [dimY , coord ],
418443 mbarId = stage ,
419444 predicate = producerPrimaryThread ,
@@ -422,7 +447,7 @@ def generate_matmul_ws(
422447 nvgpu .TmaAsyncLoadOp (
423448 b_tma_slice_2 ,
424449 mbarTMA ,
425- b_tma_desc ,
450+ b_tma_desc_op ,
426451 coordinates = [dimY2 , coord ],
427452 mbarId = stage ,
428453 predicate = producerPrimaryThread ,
@@ -514,10 +539,10 @@ def generate_matmul_ws(
514539 predicate = consumerPrimaryThread ,
515540 )
516541 da = nvgpu .WarpgroupGenerateDescriptorOp (
517- a_wgmma_ty , a_tile_slice , a_tma_desc
542+ a_wgmma_ty , a_tile_slice , a_tma_desc_op
518543 )
519544 db = nvgpu .WarpgroupGenerateDescriptorOp (
520- b_wgmma_ty , b_tile_slice , b_tma_desc
545+ b_wgmma_ty , b_tile_slice , b_tma_desc_op
521546 )
522547
523548 # Step 6.3.3. MMA
@@ -679,28 +704,6 @@ def generate_matmul_multistage(
679704 + str (num_stages )
680705 + ">"
681706 )
682- a_tma_desc_ty = ir .Type .parse (
683- "!nvgpu.tensormap.descriptor<tensor = memref<"
684- + str (BLOCK_M )
685- + "x"
686- + str (TMA_LAST_DIM_F16 )
687- + "x"
688- + str (a_elem_ty )
689- + ", "
690- + str (smem_space )
691- + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
692- )
693- b_tma_desc_ty = ir .Type .parse (
694- "!nvgpu.tensormap.descriptor<tensor = memref<"
695- + str (BLOCK_K )
696- + "x"
697- + str (TMA_LAST_DIM_F16 )
698- + "x"
699- + str (b_elem_ty )
700- + ", "
701- + str (smem_space )
702- + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
703- )
704707 acc_ty = ir .Type .parse (
705708 "!nvgpu.warpgroup.accumulator<fragmented=vector<"
706709 + str (BLOCK_M )
@@ -767,21 +770,26 @@ def generate_matmul_multistage(
767770 t7 = gpu .wait (token_ty , [t6 ])
768771
769772 # Step 2. Create TMA Descriptors
770- tma_specs = [
771- (a_device , a_tma_desc_ty , a_tma_shape ),
772- (b_device , b_tma_desc_ty , b_tma_shape ),
773- ]
774- tma_descs = []
775- for x_device , tensor_map_ty , tile_shape in tma_specs :
776- x_unranked = memref .cast (
777- ir .UnrankedMemRefType .get (a_elem_ty , a_ty .memory_space ), x_device
778- )
779- tma_descs .append (
780- nvgpu .TmaCreateDescriptorOp (
781- tensor_map_ty , x_unranked , map (c , tile_shape )
782- ).result
783- )
784- a_tma_desc , b_tma_desc = tma_descs
773+ a_tma_desc = TmaDescriptorBuilder (
774+ nvgpu .TensorMapSwizzleKind .SWIZZLE_128B ,
775+ nvgpu .TensorMapL2PromoKind .L2PROMO_NONE ,
776+ nvgpu .TensorMapOOBKind .OOB_ZERO ,
777+ nvgpu .TensorMapInterleaveKind .INTERLEAVE_NONE ,
778+ a_tma_shape ,
779+ a_ty ,
780+ )
781+
782+ b_tma_desc = TmaDescriptorBuilder (
783+ nvgpu .TensorMapSwizzleKind .SWIZZLE_128B ,
784+ nvgpu .TensorMapL2PromoKind .L2PROMO_NONE ,
785+ nvgpu .TensorMapOOBKind .OOB_ZERO ,
786+ nvgpu .TensorMapInterleaveKind .INTERLEAVE_NONE ,
787+ b_tma_shape ,
788+ b_ty ,
789+ )
790+
791+ a_tma_desc_op = a_tma_desc .tma_descriptor_op (a_device )
792+ b_tma_desc_op = b_tma_desc .tma_descriptor_op (b_device )
785793
786794 # Step 3. Launch Kernel with 1 Warpgroup
787795 cta_m = M // BLOCK_M
@@ -794,7 +802,7 @@ def generate_matmul_multistage(
794802 [t7 ],
795803 * map (c , grid ),
796804 * map (c , block ),
797- dynamicSharedMemorySize = c (smem_size , ty = T .i32 ())
805+ dynamicSharedMemorySize = c (smem_size , ty = T .i32 ()),
798806 )
799807 launch_op .body .blocks .append (* ([T .index ()] * 12 ))
800808 with ir .InsertionPoint (launch_op .body .blocks [0 ]):
@@ -819,8 +827,8 @@ def generate_matmul_multistage(
819827 gpu .barrier ()
820828
821829 # GPU Step 2. Prefetch TMA descriptors
822- nvgpu .tma_prefetch_descriptor (a_tma_desc , predicate = primaryThread )
823- nvgpu .tma_prefetch_descriptor (b_tma_desc , predicate = primaryThread )
830+ nvgpu .tma_prefetch_descriptor (a_tma_desc_op , predicate = primaryThread )
831+ nvgpu .tma_prefetch_descriptor (b_tma_desc_op , predicate = primaryThread )
824832
825833 # GPU Step 3. Prologue (global memory --> shared memory)
826834 ns = num_stages if num_stages == 1 else num_stages - 1
@@ -880,23 +888,23 @@ def generate_matmul_multistage(
880888 nvgpu .TmaAsyncLoadOp (
881889 a_tma_slice ,
882890 mbarTMA ,
883- a_tma_desc ,
891+ a_tma_desc_op ,
884892 coordinates = [coord , dimX ],
885893 mbarId = iv ,
886894 predicate = primaryThread ,
887895 )
888896 nvgpu .TmaAsyncLoadOp (
889897 b_tma_slice_1 ,
890898 mbarTMA ,
891- b_tma_desc ,
899+ b_tma_desc_op ,
892900 coordinates = [dimY , coord ],
893901 mbarId = iv ,
894902 predicate = primaryThread ,
895903 )
896904 nvgpu .TmaAsyncLoadOp (
897905 b_tma_slice_2 ,
898906 mbarTMA ,
899- b_tma_desc ,
907+ b_tma_desc_op ,
900908 coordinates = [dimY2 , coord ],
901909 mbarId = iv ,
902910 predicate = primaryThread ,
@@ -972,10 +980,10 @@ def generate_matmul_multistage(
972980 predicate = primaryThread ,
973981 )
974982 da = nvgpu .WarpgroupGenerateDescriptorOp (
975- a_wgmma_ty , a_tile_slice , a_tma_desc
983+ a_wgmma_ty , a_tile_slice , a_tma_desc_op
976984 )
977985 db = nvgpu .WarpgroupGenerateDescriptorOp (
978- b_wgmma_ty , b_tile_slice , b_tma_desc
986+ b_wgmma_ty , b_tile_slice , b_tma_desc_op
979987 )
980988
981989 # Step 4.3. MMA
@@ -1060,15 +1068,15 @@ def generate_matmul_multistage(
10601068 nvgpu .TmaAsyncLoadOp (
10611069 a_tma_slice ,
10621070 mbarTMA ,
1063- a_tma_desc ,
1071+ a_tma_desc_op ,
10641072 coordinates = [coord , dimX ],
10651073 mbarId = nextSlot ,
10661074 predicate = p ,
10671075 )
10681076 nvgpu .TmaAsyncLoadOp (
10691077 b_tma_slice_1 ,
10701078 mbarTMA ,
1071- b_tma_desc ,
1079+ b_tma_desc_op ,
10721080 coordinates = [dimY , coord ],
10731081 mbarId = nextSlot ,
10741082 predicate = p ,
@@ -1077,7 +1085,7 @@ def generate_matmul_multistage(
10771085 nvgpu .TmaAsyncLoadOp (
10781086 b_tma_slice_2 ,
10791087 mbarTMA ,
1080- b_tma_desc ,
1088+ b_tma_desc_op ,
10811089 coordinates = [dimY2 , coord ],
10821090 mbarId = nextSlot ,
10831091 predicate = p ,
0 commit comments