Skip to content

Commit 2ea8c3e

Browse files
author
Siavash Nazari
committed
Add shard Dialect Python Bindings
1 parent 37af81f commit 2ea8c3e

File tree

5 files changed

+153
-0
lines changed

5 files changed

+153
-0
lines changed

mlir/python/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,15 @@ declare_mlir_dialect_python_bindings(
336336
dialects/memref.py
337337
DIALECT_NAME memref)
338338

339+
declare_mlir_dialect_python_bindings(
340+
ADD_TO_PARENT MLIRPythonSources.Dialects
341+
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
342+
TD_FILE dialects/ShardOps.td
343+
SOURCES
344+
dialects/shard.py
345+
DIALECT_NAME shard
346+
GEN_ENUM_BINDINGS)
347+
339348
declare_mlir_dialect_python_bindings(
340349
ADD_TO_PARENT MLIRPythonSources.Dialects
341350
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
//===-- ShardOps.td - Entry point for ShardOps bindings ---------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef PYTHON_BINDINGS_SHARD_OPS
10+
#define PYTHON_BINDINGS_SHARD_OPS
11+
12+
include "mlir/Dialect/Shard/IR/ShardOps.td"
13+
14+
#endif

mlir/python/mlir/dialects/shard.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
from ._shard_ops_gen import *
6+
from ._shard_ops_gen import _Dialect
7+
from ._shard_enum_gen import *
8+
9+
try:
10+
from ..ir import *
11+
from ._ods_common import (
12+
get_default_loc_context as _get_default_loc_context,
13+
_cext as _ods_cext,
14+
get_op_result_or_op_results as _get_op_result_or_op_results,
15+
)
16+
17+
from typing import Any, List, Union
18+
except ImportError as e:
19+
raise RuntimeError("Error loading imports from extension module") from e
20+
21+
22+
# The shard dialect currently doesn't need custom Python implementations for its operations
23+
# like the arith dialect does for ConstantOp. Most operations can use the generated bindings.
24+
# If specialized Python methods are needed for specific operations in the future,
25+
# they can be added here using the @_ods_cext.register_operation decorator pattern.

mlir/test/python/dialects/shard.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# RUN: %PYTHON %s | FileCheck %s
2+
3+
from mlir.ir import *
4+
from mlir.dialects import shard
5+
from mlir.dialects import func
6+
7+
8+
def constructAndPrintInModule(f):
9+
print("\nTEST:", f.__name__)
10+
with Context(), Location.unknown():
11+
module = Module.create()
12+
with InsertionPoint(module.body):
13+
f()
14+
print(module)
15+
return f
16+
17+
18+
# CHECK-LABEL: TEST: testShardGrid
19+
@constructAndPrintInModule
20+
def testShardGrid():
21+
# Test creating shard grids with different shapes
22+
grid2d = shard.GridOp("grid_2d", [2, 2])
23+
grid1d = shard.GridOp("grid_1d", [4])
24+
grid_dynamic = shard.GridOp("grid_dynamic", [2, -1]) # -1 for dynamic dimension
25+
26+
27+
# CHECK: shard.grid @grid_2d(shape = 2x2)
28+
# CHECK: shard.grid @grid_1d(shape = 4)
29+
# CHECK: shard.grid @grid_dynamic(shape = 2x?)
30+
31+
32+
# CHECK-LABEL: TEST: testCollectiveOperations
33+
@constructAndPrintInModule
34+
def testCollectiveOperations():
35+
# Create grid and types
36+
grid = shard.GridOp("grid_2x2", [2, 2])
37+
i32 = IntegerType.get_signless(32)
38+
input_type = RankedTensorType.get([4, 2], i32)
39+
gather_result_type = RankedTensorType.get([4, 4], i32)
40+
41+
# Create a function to hold the operations
42+
func_type = FunctionType.get([input_type], [input_type])
43+
test_func = func.FuncOp("test_collectives", func_type)
44+
45+
with InsertionPoint(test_func.add_entry_block()):
46+
arg = test_func.entry_block.arguments[0]
47+
48+
# All-gather operation
49+
gather_op = shard.AllGatherOp(
50+
input=arg,
51+
grid=FlatSymbolRefAttr.get("grid_2x2"),
52+
grid_axes=ArrayAttr.get([IntegerAttr.get(i32, 1)]),
53+
gather_axis=IntegerAttr.get(i32, 1),
54+
result=gather_result_type
55+
)
56+
57+
# All-reduce operation (ReductionKind might need different construction)
58+
reduce_op = shard.AllReduceOp(
59+
input=arg,
60+
grid=FlatSymbolRefAttr.get("grid_2x2"),
61+
reduction=IntegerAttr.get(IntegerType.get_signless(32), 1), # 1 = sum from enum
62+
result=input_type
63+
)
64+
65+
# Return the reduced result
66+
func.ReturnOp([reduce_op])
67+
68+
69+
# CHECK: shard.grid @grid_2x2(shape = 2x2)
70+
# CHECK: func @test_collectives(%{{.*}}: tensor<4x2xi32>) -> tensor<4x2xi32>
71+
# CHECK: %{{.*}} = shard.all_gather %{{.*}} on @grid_2x2 grid_axes = [1] gather_axis = 1 : tensor<4x2xi32> -> tensor<4x4xi32>
72+
# CHECK: %{{.*}} = shard.all_reduce %{{.*}} on @grid_2x2 reduction = sum : tensor<4x2xi32> -> tensor<4x2xi32>
73+

utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -981,6 +981,38 @@ filegroup(
981981
],
982982
)
983983

984+
##---------------------------------------------------------------------------##
985+
# Shard dialect.
986+
##---------------------------------------------------------------------------##
987+
988+
gentbl_filegroup(
989+
name = "ShardOpsPyGen",
990+
tbl_outs = {
991+
"mlir/dialects/_shard_enum_gen.py": [
992+
"-gen-python-enum-bindings",
993+
"-bind-dialect=shard",
994+
],
995+
"mlir/dialects/_shard_ops_gen.py": [
996+
"-gen-python-op-bindings",
997+
"-bind-dialect=shard",
998+
],
999+
},
1000+
tblgen = "//mlir:mlir-tblgen",
1001+
td_file = "mlir/dialects/ShardOps.td",
1002+
deps = [
1003+
"//mlir:OpBaseTdFiles",
1004+
"//mlir:ShardTdFiles",
1005+
],
1006+
)
1007+
1008+
filegroup(
1009+
name = "ShardOpsPyFiles",
1010+
srcs = [
1011+
"mlir/dialects/shard.py",
1012+
":ShardOpsPyGen",
1013+
],
1014+
)
1015+
9841016
##---------------------------------------------------------------------------##
9851017
# Shape dialect.
9861018
##---------------------------------------------------------------------------##

0 commit comments

Comments
 (0)