# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import shard
from mlir.dialects import func


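# Decorator helper: runs the test body inside a fresh module at the module's
# insertion point, then prints the module so FileCheck can match the CHECK
# lines below.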
def constructAndPrintInModule(f):
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            f()
        print(module)
    return f


# CHECK-LABEL: TEST: testShardGrid
@constructAndPrintInModule
def testShardGrid():
    # Test creating shard grids with different shapes
    grid2d = shard.GridOp("grid_2d", [2, 2])
    grid1d = shard.GridOp("grid_1d", [4])
    grid_dynamic = shard.GridOp("grid_dynamic", [2, -1])  # -1 for dynamic dimension

    # CHECK: "shard.grid"() <{shape = array<i64: 2, 2>, sym_name = "grid_2d"}> : () -> ()
    # CHECK: "shard.grid"() <{shape = array<i64: 4>, sym_name = "grid_1d"}> : () -> ()
    # CHECK: "shard.grid"() <{shape = array<i64: 2, -1>, sym_name = "grid_dynamic"}> : () -> ()


# CHECK-LABEL: TEST: testCollectiveOperations
@constructAndPrintInModule
def testCollectiveOperations():
    # Create grid and types
    grid = shard.GridOp("grid_2x2", [2, 2])
    i32 = IntegerType.get_signless(32)
    input_type = RankedTensorType.get([4, 2], i32)
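    # Gathering across a grid axis of size 2 doubles the gathered dimension,
    # so 4x2 per-device inputs produce a 4x4 result.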
    gather_result_type = RankedTensorType.get([4, 4], i32)

    # Create a function to hold the operations
    func_type = FunctionType.get([input_type], [input_type])
    test_func = func.FuncOp("test_collectives", func_type)

    with InsertionPoint(test_func.add_entry_block()):
        arg = test_func.entry_block.arguments[0]

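        # shard.all_gather over grid axis 1 of @grid_2x2, concatenating the
        # per-device shards along tensor axis 1.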
        gather_op = shard.AllGatherOp(
            input=arg,
            grid=FlatSymbolRefAttr.get("grid_2x2"),
            grid_axes=ArrayAttr.get([IntegerAttr.get(i32, 1)]),
            gather_axis=IntegerAttr.get(i32, 1),
            result=gather_result_type,
        )

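        # shard.all_reduce with an element-wise sum reduction across the grid;
        # the result keeps the input's 4x2 shape.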
        reduce_op = shard.AllReduceOp(
            input=arg,
            grid=FlatSymbolRefAttr.get("grid_2x2"),
            reduction=shard.ReductionKind.Sum,
            result=input_type,
        )

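        # Return the reduced tensor so the function body is properly terminated.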
        func.ReturnOp([reduce_op])

    # CHECK: "shard.grid"() <{shape = array<i64: 2, 2>, sym_name = "grid_2x2"}> : () -> ()
    # CHECK: "func.func"() <{function_type = (tensor<4x2xi32>) -> tensor<4x2xi32>, sym_name = "test_collectives"}>
    # CHECK: "shard.all_gather"({{.*}}) <{gather_axis = 1 : i32, grid = @grid_2x2}> : (tensor<4x2xi32>) -> tensor<4x4xi32>
    # CHECK: "shard.all_reduce"({{.*}}) <{grid = @grid_2x2, {{.*}} reduction = #shard<partial sum>}> : (tensor<4x2xi32>) -> tensor<4x2xi32>