9 changes: 9 additions & 0 deletions mlir/python/CMakeLists.txt
@@ -336,6 +336,15 @@ declare_mlir_dialect_python_bindings(
dialects/memref.py
DIALECT_NAME memref)

declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/ShardOps.td
SOURCES
dialects/shard.py
DIALECT_NAME shard
GEN_ENUM_BINDINGS)

declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
14 changes: 14 additions & 0 deletions mlir/python/mlir/dialects/ShardOps.td
@@ -0,0 +1,14 @@
//===-- ShardOps.td - Entry point for ShardOps bindings ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef PYTHON_BINDINGS_SHARD_OPS
#define PYTHON_BINDINGS_SHARD_OPS

include "mlir/Dialect/Shard/IR/ShardOps.td"

#endif
6 changes: 6 additions & 0 deletions mlir/python/mlir/dialects/shard.py
@@ -0,0 +1,6 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ._shard_ops_gen import *
from ._shard_enum_gen import *
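
With these wildcard imports in place, the generated op and enum classes become available directly from mlir.dialects.shard. A minimal usage sketch, assuming an MLIR build with the Python bindings enabled (the grid name and shape here are illustrative):

from mlir.ir import Context, InsertionPoint, Location, Module
from mlir.dialects import shard

with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
        # Declare a 2x4 device grid named "grid"; -1 would mark a dynamic dimension.
        shard.GridOp("grid", [2, 4])
    # Prints the generic form, as in the test below:
    # "shard.grid"() <{shape = array<i64: 2, 4>, sym_name = "grid"}> : () -> ()
    print(module)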
67 changes: 67 additions & 0 deletions mlir/test/python/dialects/shard.py
@@ -0,0 +1,67 @@
# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import shard
from mlir.dialects import func


def constructAndPrintInModule(f):
    """Run f with an insertion point inside a fresh module, then print the
    module so FileCheck can verify the generated IR."""
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            f()
        print(module)
    return f


# CHECK-LABEL: TEST: testShardGrid
@constructAndPrintInModule
def testShardGrid():
# Test creating shard grids with different shapes
grid2d = shard.GridOp("grid_2d", [2, 2])
grid1d = shard.GridOp("grid_1d", [4])
grid_dynamic = shard.GridOp("grid_dynamic", [2, -1]) # -1 for dynamic dimension

# CHECK: "shard.grid"() <{shape = array<i64: 2, 2>, sym_name = "grid_2d"}> : () -> ()
# CHECK: "shard.grid"() <{shape = array<i64: 4>, sym_name = "grid_1d"}> : () -> ()
# CHECK: "shard.grid"() <{shape = array<i64: 2, -1>, sym_name = "grid_dynamic"}> : () -> ()


# CHECK-LABEL: TEST: testCollectiveOperations
@constructAndPrintInModule
def testCollectiveOperations():
# Create grid and types
grid = shard.GridOp("grid_2x2", [2, 2])
i32 = IntegerType.get_signless(32)
input_type = RankedTensorType.get([4, 2], i32)
gather_result_type = RankedTensorType.get([4, 4], i32)

# Create a function to hold the operations
func_type = FunctionType.get([input_type], [input_type])
test_func = func.FuncOp("test_collectives", func_type)

with InsertionPoint(test_func.add_entry_block()):
arg = test_func.entry_block.arguments[0]

gather_op = shard.AllGatherOp(
input=arg,
grid=FlatSymbolRefAttr.get("grid_2x2"),
grid_axes=ArrayAttr.get([IntegerAttr.get(i32, 1)]),
gather_axis=IntegerAttr.get(i32, 1),
result=gather_result_type,
)

reduce_op = shard.AllReduceOp(
input=arg,
grid=FlatSymbolRefAttr.get("grid_2x2"),
reduction=shard.ReductionKind.Sum,
result=input_type,
)

func.ReturnOp([reduce_op])

# CHECK: "shard.grid"() <{shape = array<i64: 2, 2>, sym_name = "grid_2x2"}> : () -> ()
# CHECK: "func.func"() <{function_type = (tensor<4x2xi32>) -> tensor<4x2xi32>, sym_name = "test_collectives"}>
# CHECK: "shard.all_gather"({{.*}}) <{gather_axis = 1 : i32, grid = @grid_2x2}> : (tensor<4x2xi32>) -> tensor<4x4xi32>
# CHECK: "shard.all_reduce"({{.*}}) <{grid = @grid_2x2, {{.*}} reduction = #shard<partial sum>}> : (tensor<4x2xi32>) -> tensor<4x2xi32>
32 changes: 32 additions & 0 deletions utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
@@ -981,6 +981,38 @@ filegroup(
],
)

##---------------------------------------------------------------------------##
# Shard dialect.
##---------------------------------------------------------------------------##

gentbl_filegroup(
name = "ShardOpsPyGen",
tbl_outs = {
"mlir/dialects/_shard_enum_gen.py": [
"-gen-python-enum-bindings",
"-bind-dialect=shard",
],
"mlir/dialects/_shard_ops_gen.py": [
"-gen-python-op-bindings",
"-bind-dialect=shard",
],
},
tblgen = "//mlir:mlir-tblgen",
td_file = "mlir/dialects/ShardOps.td",
deps = [
"//mlir:OpBaseTdFiles",
"//mlir:ShardTdFiles",
],
)

filegroup(
name = "ShardOpsPyFiles",
srcs = [
"mlir/dialects/shard.py",
":ShardOpsPyGen",
],
)
Comment on lines +988 to +1014

Contributor:
Unless you work at Google (or you really do use Bazel) you don't need to do this; the Bazel build is maintained by the users that actually use Bazel.

Author:
Thanks for pointing this out! I'm actually interested in building my project with Bazel, so let's keep it.

Contributor:
Ok sure, np. Just gotta get someone that knows Bazel (not me, lol) to sign off. @jpienaar @superbobry, does this look right to you?

Author:
Thanks @makslevental! Seems like the Bazel build is fine. Can I get reviews, please? @jpienaar @superbobry


##---------------------------------------------------------------------------##
# Shape dialect.
##---------------------------------------------------------------------------##