Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions mlir/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,15 @@ declare_mlir_dialect_python_bindings(
dialects/memref.py
DIALECT_NAME memref)

declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/ShardOps.td
SOURCES
dialects/shard.py
DIALECT_NAME shard
GEN_ENUM_BINDINGS)

declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
Expand Down
14 changes: 14 additions & 0 deletions mlir/python/mlir/dialects/ShardOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
//===-- ShardOps.td - Entry point for ShardOps bindings ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef PYTHON_BINDINGS_SHARD_OPS
#define PYTHON_BINDINGS_SHARD_OPS

include "mlir/Dialect/Shard/IR/ShardOps.td"

#endif
6 changes: 6 additions & 0 deletions mlir/python/mlir/dialects/shard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ._shard_ops_gen import *
from ._shard_enum_gen import *
67 changes: 67 additions & 0 deletions mlir/test/python/dialects/shard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import shard
from mlir.dialects import func


def constructAndPrintInModule(f):
print("\nTEST:", f.__name__)
with Context(), Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
f()
print(module)
module.operation.verify()
return f


# CHECK-LABEL: TEST: testShardGrid
@constructAndPrintInModule
def testShardGrid():
# Test creating shard grids with different shapes
grid2d = shard.GridOp("grid_2d", [2, 2])
grid1d = shard.GridOp("grid_1d", [4])

# CHECK: shard.grid @grid_2d(shape = 2x2)
# CHECK: shard.grid @grid_1d(shape = 4)


# CHECK-LABEL: TEST: testCollectiveOperations
@constructAndPrintInModule
def testCollectiveOperations():
# Create grid and types
grid_op = shard.GridOp("grid_2x2", [2, 2])
i32 = IntegerType.get_signless(32)
index_type = IndexType.get()
input_type = RankedTensorType.get([4, 2], i32)
gather_result_type = RankedTensorType.get([4, 4], i32)

# Create a function to hold the operations
func_type = FunctionType.get([input_type], [input_type])
test_func = func.FuncOp("test_collectives", func_type)

with InsertionPoint(test_func.add_entry_block()):
arg = test_func.entry_block.arguments[0]

gather_op = shard.AllGatherOp(
input=arg,
grid=FlatSymbolRefAttr.get("grid_2x2"),
grid_axes=DenseI16ArrayAttr.get([1]),
gather_axis=IntegerAttr.get(index_type, 1),
result=gather_result_type,
)

reduce_op = shard.AllReduceOp(
input=arg,
grid=FlatSymbolRefAttr.get("grid_2x2"),
reduction=shard.ReductionKind.Sum,
result=input_type,
)

func.ReturnOp([reduce_op])

# CHECK: shard.grid @grid_2x2(shape = 2x2)
# CHECK: func.func @test_collectives(%arg0: tensor<4x2xi32>) -> tensor<4x2xi32>
# CHECK: %all_gather = shard.all_gather %arg0 on @grid_2x2 grid_axes = [1] gather_axis = 1 : tensor<4x2xi32> -> tensor<4x4xi32>
# CHECK: %all_reduce = shard.all_reduce %arg0 on @grid_2x2 : tensor<4x2xi32> -> tensor<4x2xi32>
32 changes: 32 additions & 0 deletions utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -981,6 +981,38 @@ filegroup(
],
)

##---------------------------------------------------------------------------##
# Shard dialect.
##---------------------------------------------------------------------------##

gentbl_filegroup(
name = "ShardOpsPyGen",
tbl_outs = {
"mlir/dialects/_shard_enum_gen.py": [
"-gen-python-enum-bindings",
"-bind-dialect=shard",
],
"mlir/dialects/_shard_ops_gen.py": [
"-gen-python-op-bindings",
"-bind-dialect=shard",
],
},
tblgen = "//mlir:mlir-tblgen",
td_file = "mlir/dialects/ShardOps.td",
deps = [
"//mlir:OpBaseTdFiles",
"//mlir:ShardTdFiles",
],
)

filegroup(
name = "ShardOpsPyFiles",
srcs = [
"mlir/dialects/shard.py",
":ShardOpsPyGen",
],
)

##---------------------------------------------------------------------------##
# Shape dialect.
##---------------------------------------------------------------------------##
Expand Down
Loading