@@ -74,7 +74,8 @@ def is_mosaic_gpu_op(op: ir.OpView) -> bool:
7474
7575def workgroup_ptr_ty () -> ir .Type :
7676 workgroup_nvptx_address_space = gpu_address_space_to_nvptx (
77- gpu .AddressSpace .Workgroup )
77+ gpu .AddressSpace .Workgroup
78+ )
7879 return ir .Type .parse (f"!llvm.ptr<{ workgroup_nvptx_address_space } >" )
7980
8081
@@ -95,7 +96,9 @@ def test_initialize_barrier_op_result_memref_must_wrap_barriers(self):
9596 with ir .InsertionPoint (self .module .body ):
9697 mgpu .initialize_barrier (
9798 ir .MemRefType .get ((1 , 2 ), ir .F32Type .get ()),
98- llvm .UndefOp (workgroup_ptr_ty ()), arrival_count = 1 )
99+ llvm .UndefOp (workgroup_ptr_ty ()),
100+ arrival_count = 1 ,
101+ )
99102 with self .assertRaisesRegex (
100103 ir .MLIRError , "must be memref of barrier values"
101104 ):
@@ -106,7 +109,8 @@ def test_initialize_barrier_op_arrival_count_must_be_strictly_positive(self):
106109 mgpu .initialize_barrier (
107110 ir .MemRefType .get ((1 , 2 ), ir .Type .parse ("!mosaic_gpu.barrier" )),
108111 llvm .UndefOp (workgroup_ptr_ty ()),
109- arrival_count = 0 )
112+ arrival_count = 0 ,
113+ )
110114 with self .assertRaisesRegex (ir .MLIRError , "value is positive" ):
111115 self .module .operation .verify ()
112116
@@ -115,7 +119,8 @@ def test_initialize_barrier_op_with_a_non_shared_base_pointer_fails(self):
115119 mgpu .initialize_barrier (
116120 ir .MemRefType .get ((1 , 2 ), ir .Type .parse ("!mosaic_gpu.barrier" )),
117121 llvm .UndefOp (ir .Type .parse (f"!llvm.ptr<{ 0 } >" )),
118- arrival_count = 1 )
122+ arrival_count = 1 ,
123+ )
119124 with self .assertRaisesRegex (ir .MLIRError , "pointer in address space 3" ):
120125 self .module .operation .verify ()
121126
@@ -124,10 +129,12 @@ def test_initialize_barrier_op_with_a_positive_arrival_count_passes(self):
124129 mgpu .initialize_barrier (
125130 ir .MemRefType .get ((1 , 2 ), ir .Type .parse ("!mosaic_gpu.barrier" )),
126131 llvm .UndefOp (workgroup_ptr_ty ()),
127- arrival_count = 1 )
132+ arrival_count = 1 ,
133+ )
128134 self .assertTrue (self .module .operation .verify ())
129- self .assertIsInstance (self .module .body .operations [1 ],
130- mgpu .InitializeBarrierOp )
135+ self .assertIsInstance (
136+ self .module .body .operations [1 ], mgpu .InitializeBarrierOp
137+ )
131138
132139 def test_async_load_op_dest_must_be_contiguous (self ):
133140 with ir .InsertionPoint (self .module .body ):
@@ -575,7 +582,8 @@ def test_lowering_removes_mosaic_gpu_ops(self):
575582 mgpu .initialize_barrier (
576583 ir .MemRefType .get ((1 , 2 ), ir .Type .parse ("!mosaic_gpu.barrier" )),
577584 llvm .UndefOp (workgroup_ptr_ty ()),
578- arrival_count = 1 )
585+ arrival_count = 1 ,
586+ )
579587 lower_mgpu_dialect (self .module )
580588
581589 self .assertEmpty (
@@ -591,7 +599,8 @@ def test_lowering_traverses_regions_correctly(self):
591599 mgpu .initialize_barrier (
592600 ir .MemRefType .get ((1 , 2 ), ir .Type .parse ("!mosaic_gpu.barrier" )),
593601 llvm .UndefOp (workgroup_ptr_ty ()),
594- arrival_count = 1 )
602+ arrival_count = 1 ,
603+ )
595604 scf .yield_ ([])
596605 lower_mgpu_dialect (self .module )
597606
@@ -608,7 +617,8 @@ def test_initialize_barrier_op_lowering_rule(self):
608617 barriers_ref = mgpu .initialize_barrier (
609618 ir .MemRefType .get (shape , ir .Type .parse ("!mosaic_gpu.barrier" )),
610619 llvm .UndefOp (workgroup_ptr_ty ()),
611- arrival_count = arrival_count )
620+ arrival_count = arrival_count ,
621+ )
612622 # Add a user for barriers_ref to make sure that the lowering keeps types
613623 # consistent.
614624 memref .copy (barriers_ref , barriers_ref )
0 commit comments