From e3d7af0a93149687b3ade569013ad07e2ac038c9 Mon Sep 17 00:00:00 2001 From: Siavash Nazari Date: Wed, 8 Oct 2025 22:07:33 -0700 Subject: [PATCH] [MLIR][Python] Add shard Dialect Python Bindings --- mlir/python/CMakeLists.txt | 9 +++ mlir/python/mlir/dialects/ShardOps.td | 14 ++++ mlir/python/mlir/dialects/shard.py | 6 ++ mlir/test/python/dialects/shard.py | 67 +++++++++++++++++++ .../mlir/python/BUILD.bazel | 32 +++++++++ 5 files changed, 128 insertions(+) create mode 100644 mlir/python/mlir/dialects/ShardOps.td create mode 100644 mlir/python/mlir/dialects/shard.py create mode 100644 mlir/test/python/dialects/shard.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 9f5246de6bda0..20f07440df2c3 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -336,6 +336,15 @@ declare_mlir_dialect_python_bindings( dialects/memref.py DIALECT_NAME memref) +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/ShardOps.td + SOURCES + dialects/shard.py + DIALECT_NAME shard + GEN_ENUM_BINDINGS) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/ShardOps.td b/mlir/python/mlir/dialects/ShardOps.td new file mode 100644 index 0000000000000..f8527664df67b --- /dev/null +++ b/mlir/python/mlir/dialects/ShardOps.td @@ -0,0 +1,14 @@ +//===-- ShardOps.td - Entry point for ShardOps bindings ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_SHARD_OPS +#define PYTHON_BINDINGS_SHARD_OPS + +include "mlir/Dialect/Shard/IR/ShardOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/shard.py b/mlir/python/mlir/dialects/shard.py new file mode 100644 index 0000000000000..8d69f17954290 --- /dev/null +++ b/mlir/python/mlir/dialects/shard.py @@ -0,0 +1,6 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._shard_ops_gen import * +from ._shard_enum_gen import * diff --git a/mlir/test/python/dialects/shard.py b/mlir/test/python/dialects/shard.py new file mode 100644 index 0000000000000..52a761eac0bfc --- /dev/null +++ b/mlir/test/python/dialects/shard.py @@ -0,0 +1,67 @@ +# RUN: %PYTHON %s | FileCheck %s + +from mlir.ir import * +from mlir.dialects import shard +from mlir.dialects import func + + +def constructAndPrintInModule(f): + print("\nTEST:", f.__name__) + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + f() + print(module) + return f + + +# CHECK-LABEL: TEST: testShardGrid +@constructAndPrintInModule +def testShardGrid(): + # Test creating shard grids with different shapes + grid2d = shard.GridOp("grid_2d", [2, 2]) + grid1d = shard.GridOp("grid_1d", [4]) + grid_dynamic = shard.GridOp("grid_dynamic", [2, -1]) # -1 for dynamic dimension + + # CHECK: "shard.grid"() <{shape = array, sym_name = "grid_2d"}> : () -> () + # CHECK: "shard.grid"() <{shape = array, sym_name = "grid_1d"}> : () -> () + # CHECK: "shard.grid"() <{shape = array, sym_name = "grid_dynamic"}> : () -> () + + +# CHECK-LABEL: TEST: testCollectiveOperations +@constructAndPrintInModule +def testCollectiveOperations(): + # Create grid and types + grid = shard.GridOp("grid_2x2", [2, 2]) + i32 = IntegerType.get_signless(32) + input_type = RankedTensorType.get([4, 2], i32) + gather_result_type = RankedTensorType.get([4, 4], i32) + + # Create a function to hold the operations + func_type = FunctionType.get([input_type], [input_type]) + test_func = func.FuncOp("test_collectives", func_type) + + with InsertionPoint(test_func.add_entry_block()): + arg = test_func.entry_block.arguments[0] + + gather_op = shard.AllGatherOp( + input=arg, + grid=FlatSymbolRefAttr.get("grid_2x2"), + grid_axes=ArrayAttr.get([IntegerAttr.get(i32, 1)]), + gather_axis=IntegerAttr.get(i32, 1), + result=gather_result_type, + ) + + reduce_op = shard.AllReduceOp( + input=arg, + grid=FlatSymbolRefAttr.get("grid_2x2"), + reduction=shard.ReductionKind.Sum, + result=input_type, + ) + + func.ReturnOp([reduce_op]) + + # CHECK: "shard.grid"() <{shape = array, sym_name = "grid_2x2"}> : () -> () + # CHECK: "func.func"() <{function_type = (tensor<4x2xi32>) -> tensor<4x2xi32>, sym_name = "test_collectives"}> + # CHECK: "shard.all_gather"({{.*}}) <{gather_axis = 1 : i32, grid = @grid_2x2}> : (tensor<4x2xi32>) -> tensor<4x4xi32> + # CHECK: "shard.all_reduce"({{.*}}) <{grid = @grid_2x2, {{.*}} reduction = #shard}> : (tensor<4x2xi32>) -> tensor<4x2xi32> diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel index 102c4161eb74c..72af4f08bde57 100644 --- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel @@ -981,6 +981,38 @@ filegroup( ], ) +##---------------------------------------------------------------------------## +# Shard dialect. +##---------------------------------------------------------------------------## + +gentbl_filegroup( + name = "ShardOpsPyGen", + tbl_outs = { + "mlir/dialects/_shard_enum_gen.py": [ + "-gen-python-enum-bindings", + "-bind-dialect=shard", + ], + "mlir/dialects/_shard_ops_gen.py": [ + "-gen-python-op-bindings", + "-bind-dialect=shard", + ], + }, + tblgen = "//mlir:mlir-tblgen", + td_file = "mlir/dialects/ShardOps.td", + deps = [ + "//mlir:OpBaseTdFiles", + "//mlir:ShardTdFiles", + ], +) + +filegroup( + name = "ShardOpsPyFiles", + srcs = [ + "mlir/dialects/shard.py", + ":ShardOpsPyGen", + ], +) + ##---------------------------------------------------------------------------## # Shape dialect. ##---------------------------------------------------------------------------##