1
+ # RUN: %PYTHON %s | FileCheck %s
2
+
3
+ from mlir .ir import *
4
+ from mlir .dialects import shard
5
+ from mlir .dialects import func
6
+
7
+
8
+ def constructAndPrintInModule (f ):
9
+ print ("\n TEST:" , f .__name__ )
10
+ with Context (), Location .unknown ():
11
+ module = Module .create ()
12
+ with InsertionPoint (module .body ):
13
+ f ()
14
+ print (module )
15
+ return f
16
+
17
+
18
+ # CHECK-LABEL: TEST: testShardGrid
19
+ @constructAndPrintInModule
20
+ def testShardGrid ():
21
+ # Test creating shard grids with different shapes
22
+ grid2d = shard .GridOp ("grid_2d" , [2 , 2 ])
23
+ grid1d = shard .GridOp ("grid_1d" , [4 ])
24
+ grid_dynamic = shard .GridOp ("grid_dynamic" , [2 , - 1 ]) # -1 for dynamic dimension
25
+
26
+ # CHECK: shard.grid @grid_2d(shape = 2x2)
27
+ # CHECK: shard.grid @grid_1d(shape = 4)
28
+ # CHECK: shard.grid @grid_dynamic(shape = 2x?)
29
+
30
+
31
+ # CHECK-LABEL: TEST: testCollectiveOperations
32
+ @constructAndPrintInModule
33
+ def testCollectiveOperations ():
34
+ # Create grid and types
35
+ grid = shard .GridOp ("grid_2x2" , [2 , 2 ])
36
+ i32 = IntegerType .get_signless (32 )
37
+ input_type = RankedTensorType .get ([4 , 2 ], i32 )
38
+ gather_result_type = RankedTensorType .get ([4 , 4 ], i32 )
39
+
40
+ # Create a function to hold the operations
41
+ func_type = FunctionType .get ([input_type ], [input_type ])
42
+ test_func = func .FuncOp ("test_collectives" , func_type )
43
+
44
+ with InsertionPoint (test_func .add_entry_block ()):
45
+ arg = test_func .entry_block .arguments [0 ]
46
+
47
+ gather_op = shard .AllGatherOp (
48
+ input = arg ,
49
+ grid = FlatSymbolRefAttr .get ("grid_2x2" ),
50
+ grid_axes = ArrayAttr .get ([IntegerAttr .get (i32 , 1 )]),
51
+ gather_axis = IntegerAttr .get (i32 , 1 ),
52
+ result = gather_result_type ,
53
+ )
54
+
55
+ reduce_op = shard .AllReduceOp (
56
+ input = arg ,
57
+ grid = FlatSymbolRefAttr .get ("grid_2x2" ),
58
+ reduction = shard .ReductionKind .Sum ,
59
+ result = input_type ,
60
+ )
61
+
62
+ func .ReturnOp ([reduce_op ])
63
+
64
+ # CHECK: shard.grid @grid_2x2(shape = 2x2)
65
+ # CHECK: func @test_collectives(%{{.*}}: tensor<4x2xi32>) -> tensor<4x2xi32>
66
+ # CHECK: %{{.*}} = shard.all_gather %{{.*}} on @grid_2x2 grid_axes = [1] gather_axis = 1 : tensor<4x2xi32> -> tensor<4x4xi32>
67
+ # CHECK: %{{.*}} = shard.all_reduce %{{.*}} on @grid_2x2 reduction = sum : tensor<4x2xi32> -> tensor<4x2xi32>
0 commit comments