|
| 1 | +# RUN: %PYTHON %s | FileCheck %s |
| 2 | + |
| 3 | +from mlir.ir import * |
| 4 | +from mlir.dialects import shard |
| 5 | +from mlir.dialects import func |
| 6 | + |
| 7 | + |
| 8 | +def constructAndPrintInModule(f): |
| 9 | + print("\nTEST:", f.__name__) |
| 10 | + with Context(), Location.unknown(): |
| 11 | + module = Module.create() |
| 12 | + with InsertionPoint(module.body): |
| 13 | + f() |
| 14 | + print(module) |
| 15 | + return f |
| 16 | + |
| 17 | + |
| 18 | +# CHECK-LABEL: TEST: testShardGrid |
| 19 | +@constructAndPrintInModule |
| 20 | +def testShardGrid(): |
| 21 | + # Test creating shard grids with different shapes |
| 22 | + grid2d = shard.GridOp("grid_2d", [2, 2]) |
| 23 | + grid1d = shard.GridOp("grid_1d", [4]) |
| 24 | + grid_dynamic = shard.GridOp("grid_dynamic", [2, -1]) # -1 for dynamic dimension |
| 25 | + |
| 26 | + # CHECK: shard.grid @grid_2d(shape = 2x2) |
| 27 | + # CHECK: shard.grid @grid_1d(shape = 4) |
| 28 | + # CHECK: shard.grid @grid_dynamic(shape = 2x?) |
| 29 | + |
| 30 | + |
| 31 | +# CHECK-LABEL: TEST: testCollectiveOperations |
| 32 | +@constructAndPrintInModule |
| 33 | +def testCollectiveOperations(): |
| 34 | + # Create grid and types |
| 35 | + grid = shard.GridOp("grid_2x2", [2, 2]) |
| 36 | + i32 = IntegerType.get_signless(32) |
| 37 | + input_type = RankedTensorType.get([4, 2], i32) |
| 38 | + gather_result_type = RankedTensorType.get([4, 4], i32) |
| 39 | + |
| 40 | + # Create a function to hold the operations |
| 41 | + func_type = FunctionType.get([input_type], [input_type]) |
| 42 | + test_func = func.FuncOp("test_collectives", func_type) |
| 43 | + |
| 44 | + with InsertionPoint(test_func.add_entry_block()): |
| 45 | + arg = test_func.entry_block.arguments[0] |
| 46 | + |
| 47 | + gather_op = shard.AllGatherOp( |
| 48 | + input=arg, |
| 49 | + grid=FlatSymbolRefAttr.get("grid_2x2"), |
| 50 | + grid_axes=ArrayAttr.get([IntegerAttr.get(i32, 1)]), |
| 51 | + gather_axis=IntegerAttr.get(i32, 1), |
| 52 | + result=gather_result_type, |
| 53 | + ) |
| 54 | + |
| 55 | + reduce_op = shard.AllReduceOp( |
| 56 | + input=arg, |
| 57 | + grid=FlatSymbolRefAttr.get("grid_2x2"), |
| 58 | + reduction=shard.ReductionKind.Sum, |
| 59 | + result=input_type, |
| 60 | + ) |
| 61 | + |
| 62 | + func.ReturnOp([reduce_op]) |
| 63 | + |
| 64 | + # CHECK: shard.grid @grid_2x2(shape = 2x2) |
| 65 | + # CHECK: func @test_collectives(%{{.*}}: tensor<4x2xi32>) -> tensor<4x2xi32> |
| 66 | + # CHECK: %{{.*}} = shard.all_gather %{{.*}} on @grid_2x2 grid_axes = [1] gather_axis = 1 : tensor<4x2xi32> -> tensor<4x4xi32> |
| 67 | + # CHECK: %{{.*}} = shard.all_reduce %{{.*}} on @grid_2x2 reduction = sum : tensor<4x2xi32> -> tensor<4x2xi32> |
0 commit comments