-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][Python] Add shard Dialect Python Bindings #162578
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
//===-- ShardOps.td - Entry point for ShardOps bindings ---------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef PYTHON_BINDINGS_SHARD_OPS | ||
#define PYTHON_BINDINGS_SHARD_OPS | ||
|
||
include "mlir/Dialect/Shard/IR/ShardOps.td" | ||
|
||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from ._shard_ops_gen import * | ||
from ._shard_enum_gen import * |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# RUN: %PYTHON %s | FileCheck %s | ||
|
||
from mlir.ir import * | ||
from mlir.dialects import shard | ||
from mlir.dialects import func | ||
|
||
|
||
def constructAndPrintInModule(f): | ||
print("\nTEST:", f.__name__) | ||
with Context(), Location.unknown(): | ||
module = Module.create() | ||
with InsertionPoint(module.body): | ||
f() | ||
print(module) | ||
return f | ||
|
||
|
||
# CHECK-LABEL: TEST: testShardGrid | ||
@constructAndPrintInModule | ||
def testShardGrid(): | ||
# Test creating shard grids with different shapes | ||
grid2d = shard.GridOp("grid_2d", [2, 2]) | ||
grid1d = shard.GridOp("grid_1d", [4]) | ||
grid_dynamic = shard.GridOp("grid_dynamic", [2, -1]) # -1 for dynamic dimension | ||
|
||
# CHECK: "shard.grid"() <{shape = array<i64: 2, 2>, sym_name = "grid_2d"}> : () -> () | ||
# CHECK: "shard.grid"() <{shape = array<i64: 4>, sym_name = "grid_1d"}> : () -> () | ||
# CHECK: "shard.grid"() <{shape = array<i64: 2, -1>, sym_name = "grid_dynamic"}> : () -> () | ||
|
||
|
||
# CHECK-LABEL: TEST: testCollectiveOperations | ||
@constructAndPrintInModule | ||
def testCollectiveOperations(): | ||
# Create grid and types | ||
grid = shard.GridOp("grid_2x2", [2, 2]) | ||
i32 = IntegerType.get_signless(32) | ||
input_type = RankedTensorType.get([4, 2], i32) | ||
gather_result_type = RankedTensorType.get([4, 4], i32) | ||
|
||
# Create a function to hold the operations | ||
func_type = FunctionType.get([input_type], [input_type]) | ||
test_func = func.FuncOp("test_collectives", func_type) | ||
|
||
with InsertionPoint(test_func.add_entry_block()): | ||
arg = test_func.entry_block.arguments[0] | ||
|
||
gather_op = shard.AllGatherOp( | ||
input=arg, | ||
grid=FlatSymbolRefAttr.get("grid_2x2"), | ||
grid_axes=ArrayAttr.get([IntegerAttr.get(i32, 1)]), | ||
gather_axis=IntegerAttr.get(i32, 1), | ||
result=gather_result_type, | ||
) | ||
|
||
reduce_op = shard.AllReduceOp( | ||
input=arg, | ||
grid=FlatSymbolRefAttr.get("grid_2x2"), | ||
reduction=shard.ReductionKind.Sum, | ||
result=input_type, | ||
) | ||
|
||
func.ReturnOp([reduce_op]) | ||
|
||
# CHECK: "shard.grid"() <{shape = array<i64: 2, 2>, sym_name = "grid_2x2"}> : () -> () | ||
# CHECK: "func.func"() <{function_type = (tensor<4x2xi32>) -> tensor<4x2xi32>, sym_name = "test_collectives"}> | ||
# CHECK: "shard.all_gather"({{.*}}) <{gather_axis = 1 : i32, grid = @grid_2x2}> : (tensor<4x2xi32>) -> tensor<4x4xi32> | ||
# CHECK: "shard.all_reduce"({{.*}}) <{grid = @grid_2x2, {{.*}} reduction = #shard<partial sum>}> : (tensor<4x2xi32>) -> tensor<4x2xi32> |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -981,6 +981,38 @@ filegroup( | |
], | ||
) | ||
|
||
##---------------------------------------------------------------------------## | ||
# Shard dialect. | ||
##---------------------------------------------------------------------------## | ||
|
||
gentbl_filegroup( | ||
name = "ShardOpsPyGen", | ||
tbl_outs = { | ||
"mlir/dialects/_shard_enum_gen.py": [ | ||
"-gen-python-enum-bindings", | ||
"-bind-dialect=shard", | ||
], | ||
"mlir/dialects/_shard_ops_gen.py": [ | ||
"-gen-python-op-bindings", | ||
"-bind-dialect=shard", | ||
], | ||
}, | ||
tblgen = "//mlir:mlir-tblgen", | ||
td_file = "mlir/dialects/ShardOps.td", | ||
deps = [ | ||
"//mlir:OpBaseTdFiles", | ||
"//mlir:ShardTdFiles", | ||
], | ||
) | ||
|
||
filegroup( | ||
name = "ShardOpsPyFiles", | ||
srcs = [ | ||
"mlir/dialects/shard.py", | ||
":ShardOpsPyGen", | ||
], | ||
) | ||
Comment on lines
+988
to
+1014
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. unless you work at Google (or you really actually do use bazel) you don't need to do this (bazel build is maintained by the users that actually use bazel) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for pointing this out! I'm actually interested in building my project with Bazel, so let's keep it There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok sure np just gotta get someone that knows bazel (not me lol) to sign off. @jpienaar @superbobry this look right to you? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks @makslevental! Seems like Bazel build is fine. Can I get reviews please? @jpienaar @superbobry There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Bazel changes look good.
Your original comment was correct: a working bazel build isn't a requirement for committing changes. So you don't really need much of a sign off here. If this change breaks the bazel build, that was already allowed :) |
||
|
||
##---------------------------------------------------------------------------## | ||
# Shape dialect. | ||
##---------------------------------------------------------------------------## | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think
gather_axis
should be anIndexAttr
looking at this line.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting - so how did this pass?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like we're not calling the verifiers on the Ops in these tests. I'll update the tests to be valid, but I also find it odd that they are passing. I've seen verifiers being called as a separate step after the Ops are built.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can use
module.operation.verify()
to verify it manually. Also note that, usually if it passes the verification, the IR will be printed in custom form (if there is a custom syntax) rather than generic form (which is shown in CHECKs here).