2929from jax ._src .lib .mlir .dialects import nvvm
3030from jax ._src .lib .mlir .dialects import scf
3131from jax ._src .lib .mlir .dialects import vector
32- from jax .experimental .mosaic .gpu import dialect as mgpu # pylint: disable=g-importing-member
33- from jax .experimental .mosaic .gpu import gpu_address_space_to_nvptx # pylint: disable=g-importing-member,g-multiple-import
34- from jax .experimental .mosaic .gpu import infer_layout # pylint: disable=g-importing-member,g-multiple-import
35- from jax .experimental .mosaic .gpu import lower_mgpu_dialect # pylint: disable=g-importing-member,g-multiple-import
36- from jax .experimental .mosaic .gpu import strided_fragmented_layout # pylint: disable=g-importing-member
32+ import jax .experimental .mosaic .gpu as mgpu
3733
38- _cext = mgpu ._cext if mgpu is not None else None
34+ _cext = mgpu .dialect . _cext if mgpu . dialect is not None else None
3935
4036
4137config .parse_flags_with_absl ()
@@ -45,7 +41,7 @@ def _make_ir_context():
4541 context = ir .Context ()
4642 context .append_dialect_registry (mlir_interpreter .upstream_dialects )
4743 context .load_all_available_dialects ()
48- mgpu .register_dialect (context )
44+ mgpu .dialect . register_dialect (context )
4945 return context
5046
5147
@@ -76,7 +72,7 @@ def is_mosaic_gpu_op(op: ir.OpView) -> bool:
7672
7773
7874def workgroup_ptr_ty () -> ir .Type :
79- workgroup_nvptx_address_space = gpu_address_space_to_nvptx (
75+ workgroup_nvptx_address_space = mgpu . gpu_address_space_to_nvptx (
8076 gpu .AddressSpace .Workgroup
8177 )
8278 return ir .Type .parse (f"!llvm.ptr<{ workgroup_nvptx_address_space } >" )
@@ -85,7 +81,7 @@ def workgroup_ptr_ty() -> ir.Type:
8581class MosaicGpuTest (parameterized .TestCase ):
8682
8783 def setUp (self ):
88- if mgpu is None :
84+ if mgpu . dialect is None :
8985 raise self .skipTest ("Test requires Mosaic GPU dialect" )
9086 super ().setUp ()
9187 self .enter_context (_make_ir_context ())
@@ -100,7 +96,7 @@ def test_dialect_module_is_loaded(self):
10096
10197 def test_initialize_barrier_op_result_memref_must_wrap_barriers (self ):
10298 with ir .InsertionPoint (self .module .body ):
103- mgpu .initialize_barrier (
99+ mgpu .dialect . initialize_barrier (
104100 ir .MemRefType .get ((1 , 2 ), ir .F32Type .get ()),
105101 llvm .UndefOp (workgroup_ptr_ty ()),
106102 arrival_count = 1 ,
@@ -112,7 +108,7 @@ def test_initialize_barrier_op_result_memref_must_wrap_barriers(self):
112108
113109 def test_initialize_barrier_op_arrival_count_must_be_strictly_positive (self ):
114110 with ir .InsertionPoint (self .module .body ):
115- mgpu .initialize_barrier (
111+ mgpu .dialect . initialize_barrier (
116112 ir .MemRefType .get ((1 , 2 ), ir .Type .parse ("!mosaic_gpu.barrier" )),
117113 llvm .UndefOp (workgroup_ptr_ty ()),
118114 arrival_count = 0 ,
@@ -122,7 +118,7 @@ def test_initialize_barrier_op_arrival_count_must_be_strictly_positive(self):
122118
123119 def test_initialize_barrier_op_with_a_non_shared_base_pointer_fails (self ):
124120 with ir .InsertionPoint (self .module .body ):
125- mgpu .initialize_barrier (
121+ mgpu .dialect . initialize_barrier (
126122 ir .MemRefType .get ((1 , 2 ), ir .Type .parse ("!mosaic_gpu.barrier" )),
127123 llvm .UndefOp (ir .Type .parse (f"!llvm.ptr<{ 0 } >" )),
128124 arrival_count = 1 ,
@@ -132,14 +128,14 @@ def test_initialize_barrier_op_with_a_non_shared_base_pointer_fails(self):
132128
133129 def test_initialize_barrier_op_with_a_positive_arrival_count_passes (self ):
134130 with ir .InsertionPoint (self .module .body ):
135- mgpu .initialize_barrier (
131+ mgpu .dialect . initialize_barrier (
136132 ir .MemRefType .get ((1 , 2 ), ir .Type .parse ("!mosaic_gpu.barrier" )),
137133 llvm .UndefOp (workgroup_ptr_ty ()),
138134 arrival_count = 1 ,
139135 )
140136 self .assertTrue (self .module .operation .verify ())
141137 self .assertIsInstance (
142- self .module .body .operations [1 ], mgpu .InitializeBarrierOp
138+ self .module .body .operations [1 ], mgpu .dialect . InitializeBarrierOp
143139 )
144140
145141 def test_async_load_op_dest_must_be_contiguous (self ):
@@ -156,7 +152,7 @@ def test_async_load_op_dest_must_be_contiguous(self):
156152 ir .IntegerType .get_signless (32 ),
157153 name = "async_load" ,
158154 )(
159- lambda source , destination , barrier , * indices : mgpu .async_load (
155+ lambda source , destination , barrier , * indices : mgpu .dialect . async_load (
160156 source ,
161157 destination ,
162158 barrier ,
@@ -183,7 +179,7 @@ def test_async_load_op_source_and_dest_must_have_same_element_type(self):
183179 ir .IntegerType .get_signless (32 ),
184180 name = "async_load" ,
185181 )(
186- lambda source , destination , barrier , * indices : mgpu .async_load (
182+ lambda source , destination , barrier , * indices : mgpu .dialect . async_load (
187183 source ,
188184 destination ,
189185 barrier ,
@@ -210,7 +206,7 @@ def test_async_load_op_slice_lengths_must_be_larger_than_minus_two(self):
210206 ir .IntegerType .get_signless (32 ),
211207 name = "async_load" ,
212208 )(
213- lambda source , destination , barrier , * indices : mgpu .async_load (
209+ lambda source , destination , barrier , * indices : mgpu .dialect . async_load (
214210 source ,
215211 destination ,
216212 barrier ,
@@ -238,7 +234,7 @@ def test_async_load_op_source_and_dest_ranks_must_match_with_collapse(self):
238234 ir .IntegerType .get_signless (32 ),
239235 name = "async_load" ,
240236 )(
241- lambda source , destination , barrier , * indices : mgpu .async_load (
237+ lambda source , destination , barrier , * indices : mgpu .dialect . async_load (
242238 source ,
243239 destination ,
244240 barrier ,
@@ -264,7 +260,7 @@ def test_async_load_op_indices_size_must_match_source_rank(self):
264260 ir .IntegerType .get_signless (32 ),
265261 name = "async_load" ,
266262 )(
267- lambda source , destination , barrier , * indices : mgpu .async_load (
263+ lambda source , destination , barrier , * indices : mgpu .dialect . async_load (
268264 source ,
269265 destination ,
270266 barrier ,
@@ -290,7 +286,7 @@ def test_async_load_op_slice_lengths_size_must_match_source_rank(self):
290286 ir .IntegerType .get_signless (32 ),
291287 name = "async_load" ,
292288 )(
293- lambda source , destination , barrier , * indices : mgpu .async_load (
289+ lambda source , destination , barrier , * indices : mgpu .dialect . async_load (
294290 source ,
295291 destination ,
296292 barrier ,
@@ -316,7 +312,7 @@ def test_async_load_op_slice_collective_must_be_unique(self):
316312 ir .IntegerType .get_signless (32 ),
317313 name = "async_load" ,
318314 )(
319- lambda source , destination , barrier , * indices : mgpu .async_load (
315+ lambda source , destination , barrier , * indices : mgpu .dialect . async_load (
320316 source ,
321317 destination ,
322318 barrier ,
@@ -325,10 +321,10 @@ def test_async_load_op_slice_collective_must_be_unique(self):
325321 transforms = ir .ArrayAttr .get ([]),
326322 collective = ir .ArrayAttr .get ([
327323 ir .Attribute .parse (
328- f"#mosaic_gpu.dim<{ mgpu .Dimension .x .name } >"
324+ f"#mosaic_gpu.dim<{ mgpu .dialect . Dimension .x .name } >"
329325 ),
330326 ir .Attribute .parse (
331- f"#mosaic_gpu.dim<{ mgpu .Dimension .x .name } >"
327+ f"#mosaic_gpu.dim<{ mgpu .dialect . Dimension .x .name } >"
332328 ),
333329 ]),
334330 )
@@ -353,7 +349,7 @@ def test_async_store_op_source_must_be_contiguous(self):
353349 ir .IntegerType .get_signless (32 ),
354350 name = "async_store" ,
355351 )(
356- lambda source , destination , * indices : mgpu .async_store (
352+ lambda source , destination , * indices : mgpu .dialect . async_store (
357353 source ,
358354 destination ,
359355 indices ,
@@ -377,7 +373,7 @@ def test_async_store_op_source_and_dest_must_have_same_element_type(self):
377373 ir .IntegerType .get_signless (32 ),
378374 name = "async_store" ,
379375 )(
380- lambda source , destination , * indices : mgpu .async_store (
376+ lambda source , destination , * indices : mgpu .dialect . async_store (
381377 source ,
382378 destination ,
383379 indices ,
@@ -401,7 +397,7 @@ def test_async_store_op_slice_lengths_must_be_larger_than_minus_two(self):
401397 ir .IntegerType .get_signless (32 ),
402398 name = "async_store" ,
403399 )(
404- lambda source , destination , * indices : mgpu .async_store (
400+ lambda source , destination , * indices : mgpu .dialect . async_store (
405401 source ,
406402 destination ,
407403 indices ,
@@ -426,7 +422,7 @@ def test_async_store_op_source_and_dest_ranks_must_match_with_collapse(self):
426422 ir .IntegerType .get_signless (32 ),
427423 name = "async_store" ,
428424 )(
429- lambda source , destination , * indices : mgpu .async_store (
425+ lambda source , destination , * indices : mgpu .dialect . async_store (
430426 source ,
431427 destination ,
432428 indices ,
@@ -449,7 +445,7 @@ def test_async_store_op_indices_size_must_match_destination_rank(self):
449445 ir .IntegerType .get_signless (32 ),
450446 name = "async_store" ,
451447 )(
452- lambda source , destination , * indices : mgpu .async_store (
448+ lambda source , destination , * indices : mgpu .dialect . async_store (
453449 source ,
454450 destination ,
455451 indices ,
@@ -472,7 +468,7 @@ def test_async_store_op_slice_lengths_size_must_match_source_rank(self):
472468 ir .IntegerType .get_signless (32 ),
473469 name = "async_store" ,
474470 )(
475- lambda source , destination , * indices : mgpu .async_store (
471+ lambda source , destination , * indices : mgpu .dialect . async_store (
476472 source ,
477473 destination ,
478474 indices ,
@@ -496,7 +492,7 @@ def test_wgmma_types_match(self):
496492 ir .MemRefType .get ([4 , 5 , 32 , 32 ], ir .BF16Type .get ()),
497493 name = "wgmma" ,
498494 )(
499- lambda accumulator , a , b : mgpu .wgmma (
495+ lambda accumulator , a , b : mgpu .dialect . wgmma (
500496 accumulator ,
501497 a ,
502498 b ,
@@ -518,7 +514,7 @@ def test_wgmma_b_rank_is_4(self):
518514 ir .MemRefType .get ([5 , 32 , 32 ], ir .BF16Type .get ()),
519515 name = "wgmma" ,
520516 )(
521- lambda accumulator , a , b : mgpu .wgmma (
517+ lambda accumulator , a , b : mgpu .dialect . wgmma (
522518 accumulator ,
523519 a ,
524520 b ,
@@ -540,7 +536,7 @@ def test_wgmma_b_shape_dim_3(self):
540536 ir .MemRefType .get ([4 , 5 , 32 , 16 ], ir .BF16Type .get ()),
541537 name = "wgmma" ,
542538 )(
543- lambda accumulator , a , b : mgpu .wgmma (
539+ lambda accumulator , a , b : mgpu .dialect . wgmma (
544540 accumulator ,
545541 a ,
546542 b ,
@@ -563,7 +559,7 @@ def test_wgmma_b_shape_dim_2(self):
563559 ir .MemRefType .get ([4 , 5 , 64 , 32 ], ir .BF16Type .get ()),
564560 name = "wgmma" ,
565561 )(
566- lambda accumulator , a , b : mgpu .wgmma (
562+ lambda accumulator , a , b : mgpu .dialect . wgmma (
567563 accumulator ,
568564 a ,
569565 b ,
@@ -585,12 +581,12 @@ class DialectLoweringTest(MosaicGpuTest):
585581
586582 def test_lowering_removes_mosaic_gpu_ops (self ):
587583 with ir .InsertionPoint (self .module .body ):
588- mgpu .initialize_barrier (
584+ mgpu .dialect . initialize_barrier (
589585 ir .MemRefType .get ((1 , 2 ), ir .Type .parse ("!mosaic_gpu.barrier" )),
590586 llvm .UndefOp (workgroup_ptr_ty ()),
591587 arrival_count = 1 ,
592588 )
593- lower_mgpu_dialect (self .module )
589+ mgpu . lower_mgpu_dialect (self .module )
594590
595591 self .assertEmpty (
596592 list (filter (is_mosaic_gpu_op , self .module .body .operations ))
@@ -602,13 +598,13 @@ def test_lowering_traverses_regions_correctly(self):
602598 cst_true = arith .constant (bool_type , ir .IntegerAttr .get (bool_type , 1 ))
603599 if_op = scf .IfOp (cst_true )
604600 with ir .InsertionPoint (if_op .then_block ):
605- mgpu .initialize_barrier (
601+ mgpu .dialect . initialize_barrier (
606602 ir .MemRefType .get ((1 , 2 ), ir .Type .parse ("!mosaic_gpu.barrier" )),
607603 llvm .UndefOp (workgroup_ptr_ty ()),
608604 arrival_count = 1 ,
609605 )
610606 scf .yield_ ([])
611- lower_mgpu_dialect (self .module )
607+ mgpu . lower_mgpu_dialect (self .module )
612608
613609 self .assertEmpty (
614610 list (filter (is_mosaic_gpu_op , if_op .then_block .operations ))
@@ -620,7 +616,7 @@ def test_initialize_barrier_op_lowering_rule(self):
620616 arrival_count = 1337
621617
622618 with ir .InsertionPoint (self .module .body ):
623- barriers_ref = mgpu .initialize_barrier (
619+ barriers_ref = mgpu .dialect . initialize_barrier (
624620 ir .MemRefType .get (shape , ir .Type .parse ("!mosaic_gpu.barrier" )),
625621 llvm .UndefOp (workgroup_ptr_ty ()),
626622 arrival_count = arrival_count ,
@@ -630,7 +626,7 @@ def test_initialize_barrier_op_lowering_rule(self):
630626 memref .copy (barriers_ref , barriers_ref )
631627
632628 self .assertTrue (self .module .operation .verify ())
633- lower_mgpu_dialect (self .module )
629+ mgpu . lower_mgpu_dialect (self .module )
634630 self .assertTrue (self .module .operation .verify ())
635631
636632 all_mbarrier_init_shared_ops = find_if (
@@ -658,7 +654,7 @@ def test_lowering_vector_op_without_layout_fails(self):
658654 with self .assertRaisesRegex (
659655 ValueError , "missing a layout and can not be lowered"
660656 ):
661- lower_mgpu_dialect (self .module )
657+ mgpu . lower_mgpu_dialect (self .module )
662658
663659 def test_lowering_eliminates_layouts (self ):
664660 shape = (4 , 128 )
@@ -669,10 +665,10 @@ def test_lowering_eliminates_layouts(self):
669665 ty = ir .VectorType .get (shape , elt_ty )
670666 load = vector .load (ty , ref , [zero_index , zero_index ])
671667 load .owner .attributes ["out_layouts" ] = ir .ArrayAttr .get (
672- [strided_fragmented_layout ()]
668+ [mgpu . strided_fragmented_layout ()]
673669 )
674670
675- lower_mgpu_dialect (self .module )
671+ mgpu . lower_mgpu_dialect (self .module )
676672
677673 all_ops_with_layouts = find_if (
678674 self .module ,
@@ -692,8 +688,8 @@ def test_lowering_vector_load_and_store_ops(self):
692688 array = vector .load (ty , ref , [zero_index , zero_index ])
693689 vector .store (array , ref , [zero_index , zero_index ])
694690
695- infer_layout (self .module )
696- lower_mgpu_dialect (self .module )
691+ mgpu . infer_layout (self .module )
692+ mgpu . lower_mgpu_dialect (self .module )
697693
698694 all_loads = find_if (
699695 self .module ,
0 commit comments