Skip to content

Commit 3d9c720

Browse files
[Mosaic GPU] Automatically format the Mosaic GPU dialect test python code
This allows me to keep using the formatter going forward and not have to bother manually formatting code. PiperOrigin-RevId: 705024602
1 parent 66f45d0 commit 3d9c720

File tree

1 file changed

+20
-10
lines changed

1 file changed

+20
-10
lines changed

tests/mosaic/gpu_dialect_test.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ def is_mosaic_gpu_op(op: ir.OpView) -> bool:
7474

7575
def workgroup_ptr_ty() -> ir.Type:
7676
workgroup_nvptx_address_space = gpu_address_space_to_nvptx(
77-
gpu.AddressSpace.Workgroup)
77+
gpu.AddressSpace.Workgroup
78+
)
7879
return ir.Type.parse(f"!llvm.ptr<{workgroup_nvptx_address_space}>")
7980

8081

@@ -95,7 +96,9 @@ def test_initialize_barrier_op_result_memref_must_wrap_barriers(self):
9596
with ir.InsertionPoint(self.module.body):
9697
mgpu.initialize_barrier(
9798
ir.MemRefType.get((1, 2), ir.F32Type.get()),
98-
llvm.UndefOp(workgroup_ptr_ty()), arrival_count=1)
99+
llvm.UndefOp(workgroup_ptr_ty()),
100+
arrival_count=1,
101+
)
99102
with self.assertRaisesRegex(
100103
ir.MLIRError, "must be memref of barrier values"
101104
):
@@ -106,7 +109,8 @@ def test_initialize_barrier_op_arrival_count_must_be_strictly_positive(self):
106109
mgpu.initialize_barrier(
107110
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
108111
llvm.UndefOp(workgroup_ptr_ty()),
109-
arrival_count=0)
112+
arrival_count=0,
113+
)
110114
with self.assertRaisesRegex(ir.MLIRError, "value is positive"):
111115
self.module.operation.verify()
112116

@@ -115,7 +119,8 @@ def test_initialize_barrier_op_with_a_non_shared_base_pointer_fails(self):
115119
mgpu.initialize_barrier(
116120
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
117121
llvm.UndefOp(ir.Type.parse(f"!llvm.ptr<{0}>")),
118-
arrival_count=1)
122+
arrival_count=1,
123+
)
119124
with self.assertRaisesRegex(ir.MLIRError, "pointer in address space 3"):
120125
self.module.operation.verify()
121126

@@ -124,10 +129,12 @@ def test_initialize_barrier_op_with_a_positive_arrival_count_passes(self):
124129
mgpu.initialize_barrier(
125130
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
126131
llvm.UndefOp(workgroup_ptr_ty()),
127-
arrival_count=1)
132+
arrival_count=1,
133+
)
128134
self.assertTrue(self.module.operation.verify())
129-
self.assertIsInstance(self.module.body.operations[1],
130-
mgpu.InitializeBarrierOp)
135+
self.assertIsInstance(
136+
self.module.body.operations[1], mgpu.InitializeBarrierOp
137+
)
131138

132139
def test_async_load_op_dest_must_be_contiguous(self):
133140
with ir.InsertionPoint(self.module.body):
@@ -575,7 +582,8 @@ def test_lowering_removes_mosaic_gpu_ops(self):
575582
mgpu.initialize_barrier(
576583
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
577584
llvm.UndefOp(workgroup_ptr_ty()),
578-
arrival_count=1)
585+
arrival_count=1,
586+
)
579587
lower_mgpu_dialect(self.module)
580588

581589
self.assertEmpty(
@@ -591,7 +599,8 @@ def test_lowering_traverses_regions_correctly(self):
591599
mgpu.initialize_barrier(
592600
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
593601
llvm.UndefOp(workgroup_ptr_ty()),
594-
arrival_count=1)
602+
arrival_count=1,
603+
)
595604
scf.yield_([])
596605
lower_mgpu_dialect(self.module)
597606

@@ -608,7 +617,8 @@ def test_initialize_barrier_op_lowering_rule(self):
608617
barriers_ref = mgpu.initialize_barrier(
609618
ir.MemRefType.get(shape, ir.Type.parse("!mosaic_gpu.barrier")),
610619
llvm.UndefOp(workgroup_ptr_ty()),
611-
arrival_count=arrival_count)
620+
arrival_count=arrival_count,
621+
)
612622
# Add a user for barriers_ref to make sure that the lowering keeps types
613623
# consistent.
614624
memref.copy(barriers_ref, barriers_ref)

0 commit comments

Comments
 (0)