|
| 1 | +# RUN: %PYTHON %s | FileCheck %s |
| 2 | + |
| 3 | +from mlir.ir import * |
| 4 | +from mlir.dialects import shard |
| 5 | +from mlir.dialects import func |
| 6 | + |
| 7 | + |
| 8 | +def constructAndPrintInModule(f): |
| 9 | + print("\nTEST:", f.__name__) |
| 10 | + with Context(), Location.unknown(): |
| 11 | + module = Module.create() |
| 12 | + with InsertionPoint(module.body): |
| 13 | + f() |
| 14 | + print(module) |
| 15 | + return f |
| 16 | + |
| 17 | + |
| 18 | +# CHECK-LABEL: TEST: testShardGrid |
| 19 | +@constructAndPrintInModule |
| 20 | +def testShardGrid(): |
| 21 | + # Test creating shard grids with different shapes |
| 22 | + grid2d = shard.GridOp("grid_2d", [2, 2]) |
| 23 | + grid1d = shard.GridOp("grid_1d", [4]) |
| 24 | + grid_dynamic = shard.GridOp("grid_dynamic", [2, -1]) # -1 for dynamic dimension |
| 25 | + |
| 26 | + |
| 27 | +# CHECK: shard.grid @grid_2d(shape = 2x2) |
| 28 | +# CHECK: shard.grid @grid_1d(shape = 4) |
| 29 | +# CHECK: shard.grid @grid_dynamic(shape = 2x?) |
| 30 | + |
| 31 | + |
| 32 | +# CHECK-LABEL: TEST: testCollectiveOperations |
| 33 | +@constructAndPrintInModule |
| 34 | +def testCollectiveOperations(): |
| 35 | + # Create grid and types |
| 36 | + grid = shard.GridOp("grid_2x2", [2, 2]) |
| 37 | + i32 = IntegerType.get_signless(32) |
| 38 | + input_type = RankedTensorType.get([4, 2], i32) |
| 39 | + gather_result_type = RankedTensorType.get([4, 4], i32) |
| 40 | + |
| 41 | + # Create a function to hold the operations |
| 42 | + func_type = FunctionType.get([input_type], [input_type]) |
| 43 | + test_func = func.FuncOp("test_collectives", func_type) |
| 44 | + |
| 45 | + with InsertionPoint(test_func.add_entry_block()): |
| 46 | + arg = test_func.entry_block.arguments[0] |
| 47 | + |
| 48 | + # All-gather operation |
| 49 | + gather_op = shard.AllGatherOp( |
| 50 | + input=arg, |
| 51 | + grid=FlatSymbolRefAttr.get("grid_2x2"), |
| 52 | + grid_axes=ArrayAttr.get([IntegerAttr.get(i32, 1)]), |
| 53 | + gather_axis=IntegerAttr.get(i32, 1), |
| 54 | + result=gather_result_type |
| 55 | + ) |
| 56 | + |
| 57 | + # All-reduce operation (ReductionKind might need different construction) |
| 58 | + reduce_op = shard.AllReduceOp( |
| 59 | + input=arg, |
| 60 | + grid=FlatSymbolRefAttr.get("grid_2x2"), |
| 61 | + reduction=IntegerAttr.get(IntegerType.get_signless(32), 1), # 1 = sum from enum |
| 62 | + result=input_type |
| 63 | + ) |
| 64 | + |
| 65 | + # Return the reduced result |
| 66 | + func.ReturnOp([reduce_op]) |
| 67 | + |
| 68 | + |
| 69 | +# CHECK: shard.grid @grid_2x2(shape = 2x2) |
| 70 | +# CHECK: func @test_collectives(%{{.*}}: tensor<4x2xi32>) -> tensor<4x2xi32> |
| 71 | +# CHECK: %{{.*}} = shard.all_gather %{{.*}} on @grid_2x2 grid_axes = [1] gather_axis = 1 : tensor<4x2xi32> -> tensor<4x4xi32> |
| 72 | +# CHECK: %{{.*}} = shard.all_reduce %{{.*}} on @grid_2x2 reduction = sum : tensor<4x2xi32> -> tensor<4x2xi32> |
| 73 | + |
0 commit comments