
Commit 6a03ea3

bchetioui authored and Google-ML-Automation committed
[Mosaic GPU] Clean up imports in gpu_dialect_test.py.
PiperOrigin-RevId: 707549269
1 parent 3d54d03 commit 6a03ea3
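
All 40 added and 44 removed lines below follow one mechanical pattern: the per-symbol imports from jax.experimental.mosaic.gpu (each carrying a pylint suppression) are collapsed into a single package import, after which dialect ops are reached as mgpu.dialect.* and package-level helpers (lower_mgpu_dialect, infer_layout, gpu_address_space_to_nvptx, strided_fragmented_layout) as mgpu.*. A minimal sketch of the new style, modeled on the _make_ir_context hunk below; the standalone framing and the name make_context are illustrative, not part of the commit:

    # After this commit: one package import instead of five symbol imports.
    import jax.experimental.mosaic.gpu as mgpu

    from jax._src.interpreters import mlir as mlir_interpreter
    from jax._src.lib.mlir import ir

    def make_context() -> ir.Context:
      # Mirrors the test's _make_ir_context: build an MLIR context and
      # register the Mosaic GPU dialect via attribute access on the package.
      context = ir.Context()
      context.append_dialect_registry(mlir_interpreter.upstream_dialects)
      context.load_all_available_dialects()
      mgpu.dialect.register_dialect(context)  # was: mgpu.register_dialect(context)
      return context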

File tree

1 file changed: +40 -44 lines changed


tests/mosaic/gpu_dialect_test.py

Lines changed: 40 additions & 44 deletions
@@ -29,13 +29,9 @@
 from jax._src.lib.mlir.dialects import nvvm
 from jax._src.lib.mlir.dialects import scf
 from jax._src.lib.mlir.dialects import vector
-from jax.experimental.mosaic.gpu import dialect as mgpu  # pylint: disable=g-importing-member
-from jax.experimental.mosaic.gpu import gpu_address_space_to_nvptx  # pylint: disable=g-importing-member,g-multiple-import
-from jax.experimental.mosaic.gpu import infer_layout  # pylint: disable=g-importing-member,g-multiple-import
-from jax.experimental.mosaic.gpu import lower_mgpu_dialect  # pylint: disable=g-importing-member,g-multiple-import
-from jax.experimental.mosaic.gpu import strided_fragmented_layout  # pylint: disable=g-importing-member
+import jax.experimental.mosaic.gpu as mgpu

-_cext = mgpu._cext if mgpu is not None else None
+_cext = mgpu.dialect._cext if mgpu.dialect is not None else None


 config.parse_flags_with_absl()
@@ -45,7 +41,7 @@ def _make_ir_context():
   context = ir.Context()
   context.append_dialect_registry(mlir_interpreter.upstream_dialects)
   context.load_all_available_dialects()
-  mgpu.register_dialect(context)
+  mgpu.dialect.register_dialect(context)
   return context


@@ -76,7 +72,7 @@ def is_mosaic_gpu_op(op: ir.OpView) -> bool:


 def workgroup_ptr_ty() -> ir.Type:
-  workgroup_nvptx_address_space = gpu_address_space_to_nvptx(
+  workgroup_nvptx_address_space = mgpu.gpu_address_space_to_nvptx(
       gpu.AddressSpace.Workgroup
   )
   return ir.Type.parse(f"!llvm.ptr<{workgroup_nvptx_address_space}>")
@@ -85,7 +81,7 @@ def workgroup_ptr_ty() -> ir.Type:
 class MosaicGpuTest(parameterized.TestCase):

   def setUp(self):
-    if mgpu is None:
+    if mgpu.dialect is None:
       raise self.skipTest("Test requires Mosaic GPU dialect")
     super().setUp()
     self.enter_context(_make_ir_context())
@@ -100,7 +96,7 @@ def test_dialect_module_is_loaded(self):

   def test_initialize_barrier_op_result_memref_must_wrap_barriers(self):
     with ir.InsertionPoint(self.module.body):
-      mgpu.initialize_barrier(
+      mgpu.dialect.initialize_barrier(
           ir.MemRefType.get((1, 2), ir.F32Type.get()),
           llvm.UndefOp(workgroup_ptr_ty()),
           arrival_count=1,
@@ -112,7 +108,7 @@ def test_initialize_barrier_op_result_memref_must_wrap_barriers(self):

   def test_initialize_barrier_op_arrival_count_must_be_strictly_positive(self):
     with ir.InsertionPoint(self.module.body):
-      mgpu.initialize_barrier(
+      mgpu.dialect.initialize_barrier(
           ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
           llvm.UndefOp(workgroup_ptr_ty()),
           arrival_count=0,
@@ -122,7 +118,7 @@ def test_initialize_barrier_op_arrival_count_must_be_strictly_positive(self):

   def test_initialize_barrier_op_with_a_non_shared_base_pointer_fails(self):
     with ir.InsertionPoint(self.module.body):
-      mgpu.initialize_barrier(
+      mgpu.dialect.initialize_barrier(
           ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
           llvm.UndefOp(ir.Type.parse(f"!llvm.ptr<{0}>")),
           arrival_count=1,
@@ -132,14 +128,14 @@ def test_initialize_barrier_op_with_a_non_shared_base_pointer_fails(self):

   def test_initialize_barrier_op_with_a_positive_arrival_count_passes(self):
     with ir.InsertionPoint(self.module.body):
-      mgpu.initialize_barrier(
+      mgpu.dialect.initialize_barrier(
           ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
           llvm.UndefOp(workgroup_ptr_ty()),
           arrival_count=1,
       )
     self.assertTrue(self.module.operation.verify())
     self.assertIsInstance(
-        self.module.body.operations[1], mgpu.InitializeBarrierOp
+        self.module.body.operations[1], mgpu.dialect.InitializeBarrierOp
     )

   def test_async_load_op_dest_must_be_contiguous(self):
@@ -156,7 +152,7 @@ def test_async_load_op_dest_must_be_contiguous(self):
           ir.IntegerType.get_signless(32),
           name="async_load",
       )(
-          lambda source, destination, barrier, *indices: mgpu.async_load(
+          lambda source, destination, barrier, *indices: mgpu.dialect.async_load(
               source,
               destination,
               barrier,
@@ -183,7 +179,7 @@ def test_async_load_op_source_and_dest_must_have_same_element_type(self):
           ir.IntegerType.get_signless(32),
           name="async_load",
       )(
-          lambda source, destination, barrier, *indices: mgpu.async_load(
+          lambda source, destination, barrier, *indices: mgpu.dialect.async_load(
               source,
               destination,
               barrier,
@@ -210,7 +206,7 @@ def test_async_load_op_slice_lengths_must_be_larger_than_minus_two(self):
           ir.IntegerType.get_signless(32),
           name="async_load",
       )(
-          lambda source, destination, barrier, *indices: mgpu.async_load(
+          lambda source, destination, barrier, *indices: mgpu.dialect.async_load(
               source,
               destination,
               barrier,
@@ -238,7 +234,7 @@ def test_async_load_op_source_and_dest_ranks_must_match_with_collapse(self):
           ir.IntegerType.get_signless(32),
           name="async_load",
       )(
-          lambda source, destination, barrier, *indices: mgpu.async_load(
+          lambda source, destination, barrier, *indices: mgpu.dialect.async_load(
               source,
               destination,
               barrier,
@@ -264,7 +260,7 @@ def test_async_load_op_indices_size_must_match_source_rank(self):
           ir.IntegerType.get_signless(32),
           name="async_load",
       )(
-          lambda source, destination, barrier, *indices: mgpu.async_load(
+          lambda source, destination, barrier, *indices: mgpu.dialect.async_load(
               source,
               destination,
               barrier,
@@ -290,7 +286,7 @@ def test_async_load_op_slice_lengths_size_must_match_source_rank(self):
           ir.IntegerType.get_signless(32),
           name="async_load",
       )(
-          lambda source, destination, barrier, *indices: mgpu.async_load(
+          lambda source, destination, barrier, *indices: mgpu.dialect.async_load(
               source,
               destination,
               barrier,
@@ -316,7 +312,7 @@ def test_async_load_op_slice_collective_must_be_unique(self):
           ir.IntegerType.get_signless(32),
           name="async_load",
       )(
-          lambda source, destination, barrier, *indices: mgpu.async_load(
+          lambda source, destination, barrier, *indices: mgpu.dialect.async_load(
               source,
               destination,
               barrier,
@@ -325,10 +321,10 @@ def test_async_load_op_slice_collective_must_be_unique(self):
               transforms=ir.ArrayAttr.get([]),
               collective=ir.ArrayAttr.get([
                   ir.Attribute.parse(
-                      f"#mosaic_gpu.dim<{mgpu.Dimension.x.name}>"
+                      f"#mosaic_gpu.dim<{mgpu.dialect.Dimension.x.name}>"
                   ),
                   ir.Attribute.parse(
-                      f"#mosaic_gpu.dim<{mgpu.Dimension.x.name}>"
+                      f"#mosaic_gpu.dim<{mgpu.dialect.Dimension.x.name}>"
                   ),
               ]),
           )
@@ -353,7 +349,7 @@ def test_async_store_op_source_must_be_contiguous(self):
           ir.IntegerType.get_signless(32),
           name="async_store",
       )(
-          lambda source, destination, *indices: mgpu.async_store(
+          lambda source, destination, *indices: mgpu.dialect.async_store(
               source,
               destination,
               indices,
@@ -377,7 +373,7 @@ def test_async_store_op_source_and_dest_must_have_same_element_type(self):
           ir.IntegerType.get_signless(32),
           name="async_store",
       )(
-          lambda source, destination, *indices: mgpu.async_store(
+          lambda source, destination, *indices: mgpu.dialect.async_store(
               source,
               destination,
               indices,
@@ -401,7 +397,7 @@ def test_async_store_op_slice_lengths_must_be_larger_than_minus_two(self):
           ir.IntegerType.get_signless(32),
           name="async_store",
       )(
-          lambda source, destination, *indices: mgpu.async_store(
+          lambda source, destination, *indices: mgpu.dialect.async_store(
               source,
               destination,
               indices,
@@ -426,7 +422,7 @@ def test_async_store_op_source_and_dest_ranks_must_match_with_collapse(self):
           ir.IntegerType.get_signless(32),
           name="async_store",
       )(
-          lambda source, destination, *indices: mgpu.async_store(
+          lambda source, destination, *indices: mgpu.dialect.async_store(
               source,
               destination,
               indices,
@@ -449,7 +445,7 @@ def test_async_store_op_indices_size_must_match_destination_rank(self):
           ir.IntegerType.get_signless(32),
           name="async_store",
       )(
-          lambda source, destination, *indices: mgpu.async_store(
+          lambda source, destination, *indices: mgpu.dialect.async_store(
               source,
               destination,
               indices,
@@ -472,7 +468,7 @@ def test_async_store_op_slice_lengths_size_must_match_source_rank(self):
           ir.IntegerType.get_signless(32),
           name="async_store",
       )(
-          lambda source, destination, *indices: mgpu.async_store(
+          lambda source, destination, *indices: mgpu.dialect.async_store(
               source,
               destination,
               indices,
@@ -496,7 +492,7 @@ def test_wgmma_types_match(self):
           ir.MemRefType.get([4, 5, 32, 32], ir.BF16Type.get()),
           name="wgmma",
       )(
-          lambda accumulator, a, b: mgpu.wgmma(
+          lambda accumulator, a, b: mgpu.dialect.wgmma(
              accumulator,
              a,
              b,
@@ -518,7 +514,7 @@ def test_wgmma_b_rank_is_4(self):
           ir.MemRefType.get([5, 32, 32], ir.BF16Type.get()),
           name="wgmma",
       )(
-          lambda accumulator, a, b: mgpu.wgmma(
+          lambda accumulator, a, b: mgpu.dialect.wgmma(
              accumulator,
              a,
              b,
@@ -540,7 +536,7 @@ def test_wgmma_b_shape_dim_3(self):
           ir.MemRefType.get([4, 5, 32, 16], ir.BF16Type.get()),
           name="wgmma",
       )(
-          lambda accumulator, a, b: mgpu.wgmma(
+          lambda accumulator, a, b: mgpu.dialect.wgmma(
              accumulator,
              a,
              b,
@@ -563,7 +559,7 @@ def test_wgmma_b_shape_dim_2(self):
           ir.MemRefType.get([4, 5, 64, 32], ir.BF16Type.get()),
           name="wgmma",
       )(
-          lambda accumulator, a, b: mgpu.wgmma(
+          lambda accumulator, a, b: mgpu.dialect.wgmma(
              accumulator,
              a,
              b,
@@ -585,12 +581,12 @@ class DialectLoweringTest(MosaicGpuTest):

   def test_lowering_removes_mosaic_gpu_ops(self):
     with ir.InsertionPoint(self.module.body):
-      mgpu.initialize_barrier(
+      mgpu.dialect.initialize_barrier(
           ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
           llvm.UndefOp(workgroup_ptr_ty()),
           arrival_count=1,
       )
-    lower_mgpu_dialect(self.module)
+    mgpu.lower_mgpu_dialect(self.module)

     self.assertEmpty(
         list(filter(is_mosaic_gpu_op, self.module.body.operations))
@@ -602,13 +598,13 @@ def test_lowering_traverses_regions_correctly(self):
     cst_true = arith.constant(bool_type, ir.IntegerAttr.get(bool_type, 1))
     if_op = scf.IfOp(cst_true)
     with ir.InsertionPoint(if_op.then_block):
-      mgpu.initialize_barrier(
+      mgpu.dialect.initialize_barrier(
           ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
           llvm.UndefOp(workgroup_ptr_ty()),
           arrival_count=1,
       )
       scf.yield_([])
-    lower_mgpu_dialect(self.module)
+    mgpu.lower_mgpu_dialect(self.module)

     self.assertEmpty(
         list(filter(is_mosaic_gpu_op, if_op.then_block.operations))
@@ -620,7 +616,7 @@ def test_initialize_barrier_op_lowering_rule(self):
     arrival_count = 1337

     with ir.InsertionPoint(self.module.body):
-      barriers_ref = mgpu.initialize_barrier(
+      barriers_ref = mgpu.dialect.initialize_barrier(
           ir.MemRefType.get(shape, ir.Type.parse("!mosaic_gpu.barrier")),
           llvm.UndefOp(workgroup_ptr_ty()),
           arrival_count=arrival_count,
@@ -630,7 +626,7 @@ def test_initialize_barrier_op_lowering_rule(self):
       memref.copy(barriers_ref, barriers_ref)

     self.assertTrue(self.module.operation.verify())
-    lower_mgpu_dialect(self.module)
+    mgpu.lower_mgpu_dialect(self.module)
     self.assertTrue(self.module.operation.verify())

     all_mbarrier_init_shared_ops = find_if(
@@ -658,7 +654,7 @@ def test_lowering_vector_op_without_layout_fails(self):
     with self.assertRaisesRegex(
         ValueError, "missing a layout and can not be lowered"
     ):
-      lower_mgpu_dialect(self.module)
+      mgpu.lower_mgpu_dialect(self.module)

   def test_lowering_eliminates_layouts(self):
     shape = (4, 128)
@@ -669,10 +665,10 @@ def test_lowering_eliminates_layouts(self):
     ty = ir.VectorType.get(shape, elt_ty)
     load = vector.load(ty, ref, [zero_index, zero_index])
     load.owner.attributes["out_layouts"] = ir.ArrayAttr.get(
-        [strided_fragmented_layout()]
+        [mgpu.strided_fragmented_layout()]
     )

-    lower_mgpu_dialect(self.module)
+    mgpu.lower_mgpu_dialect(self.module)

     all_ops_with_layouts = find_if(
         self.module,
@@ -692,8 +688,8 @@ def test_lowering_vector_load_and_store_ops(self):
     array = vector.load(ty, ref, [zero_index, zero_index])
     vector.store(array, ref, [zero_index, zero_index])

-    infer_layout(self.module)
-    lower_mgpu_dialect(self.module)
+    mgpu.infer_layout(self.module)
+    mgpu.lower_mgpu_dialect(self.module)

     all_loads = find_if(
         self.module,
