Skip to content

Commit a75004b

Browse files
tomnatan30copybara-github
authored andcommitted
#sdy Copy over change to import_backend_func_calls
PiperOrigin-RevId: 738819740
1 parent 4278b71 commit a75004b

10 files changed

+315
-243
lines changed

shardy/round_trip_import/BUILD

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ cc_library(
4242
)
4343

4444
cc_library(
45-
name = "import_backend_func_calls",
46-
srcs = ["import_backend_func_calls.cc"],
47-
hdrs = ["import_backend_func_calls.h"],
45+
name = "import_uninlineable_func_calls",
46+
srcs = ["import_uninlineable_func_calls.cc"],
47+
hdrs = ["import_uninlineable_func_calls.h"],
4848
deps = [
4949
":constants",
5050
":utils",
@@ -118,10 +118,10 @@ cc_library(
118118
srcs = ["pipelines.cc"],
119119
hdrs = ["pipelines.h"],
120120
deps = [
121-
":import_backend_func_calls",
122121
":import_callback_custom_calls",
123122
":import_sdy_custom_calls",
124123
":import_shardy_attrs",
124+
":import_uninlineable_func_calls",
125125
":shard_map_import",
126126
"@llvm-project//mlir:FuncDialect",
127127
"@llvm-project//mlir:Pass",

shardy/round_trip_import/constants.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ inline constexpr llvm::StringRef kFFIPythonGpuCallbackCustomCallTargetName =
4646
// The attribute name for backend config.
4747
inline constexpr llvm::StringRef kXlaBackendConfigAttr = "backend_config";
4848

49+
// The attribute name for inlineable.
50+
inline constexpr llvm::StringRef kXlaInlineableAttr = "inlineable";
51+
4952
// Attribute name for temporarily storing the Shardy sharding during HLO
5053
// sdy-round-trip. It cannot match the name `kShardingAttr` ("sdy.sharding"), as
5154
// during sdy-round-trip, going from HLO to StableHLO, the code removes

shardy/round_trip_import/import_backend_func_calls.cc

Lines changed: 0 additions & 150 deletions
This file was deleted.
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
/* Copyright 2025 The Shardy Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "shardy/round_trip_import/import_uninlineable_func_calls.h"
17+
18+
#include <cassert>
19+
#include <iterator>
20+
#include <memory>
21+
#include <string>
22+
23+
#include "llvm/ADT/DenseMap.h"
24+
#include "llvm/ADT/STLExtras.h"
25+
#include "llvm/Support/FormatVariadic.h"
26+
#include "llvm/Support/Threading.h"
27+
#include "mlir/Dialect/Func/IR/FuncOps.h"
28+
#include "mlir/IR/Attributes.h"
29+
#include "mlir/IR/BuiltinAttributes.h"
30+
#include "mlir/IR/BuiltinOps.h"
31+
#include "mlir/IR/Diagnostics.h"
32+
#include "mlir/IR/OperationSupport.h"
33+
#include "mlir/IR/PatternMatch.h"
34+
#include "mlir/IR/SymbolTable.h"
35+
#include "mlir/Pass/Pass.h"
36+
#include "mlir/Pass/PassRegistry.h"
37+
#include "mlir/Support/LLVM.h"
38+
#include "mlir/Support/TypeID.h"
39+
#include "mlir/Transforms/DialectConversion.h"
40+
#include "shardy/dialect/sdy/ir/constants.h"
41+
#include "shardy/dialect/sdy/ir/dialect.h"
42+
#include "shardy/dialect/sdy/ir/utils.h"
43+
#include "shardy/round_trip_import/constants.h"
44+
#include "shardy/round_trip_import/utils.h"
45+
46+
namespace mlir {
47+
namespace sdy {
48+
49+
namespace {
50+
51+
using func::CallOp;
52+
using func::FuncOp;
53+
54+
bool isInlineableCallOp(CallOp callOp) {
55+
if (hasFrontendAttr(callOp, kXlaBackendConfigAttr)) {
56+
return false;
57+
}
58+
auto inlineableAttr =
59+
tryGetFrontendAttr<BoolAttr>(callOp, kXlaInlineableAttr);
60+
return !inlineableAttr || inlineableAttr->getValue();
61+
}
62+
63+
void importCallOp(
64+
CallOp callOp,
65+
llvm::SmallDenseMap<StringRef, Region*>& calleeNameToMovedRegion,
66+
IRRewriter& rewriter, SymbolTable& symbolTable) {
67+
SmallVector<NamedAttribute> namedCompAttrs;
68+
llvm::copy_if(callOp->getDiscardableAttrs(),
69+
std::back_inserter(namedCompAttrs),
70+
[](const NamedAttribute& attr) {
71+
return attr.getName() != kShardingAttr;
72+
});
73+
74+
StringRef calleeName = callOp.getCallee();
75+
rewriter.setInsertionPoint(callOp);
76+
auto namedCompOp = rewriter.create<NamedComputationOp>(
77+
callOp->getLoc(), callOp->getResultTypes(), calleeName,
78+
callOp.getOperands(),
79+
/*inShardings=*/nullptr,
80+
/*outShardings=*/getShardingPerValue(callOp));
81+
namedCompOp->setAttrs(namedCompAttrs);
82+
83+
Region& namedCompRegion = namedCompOp.getRegion();
84+
if (auto movedRegionIt = calleeNameToMovedRegion.find(calleeName);
85+
movedRegionIt != calleeNameToMovedRegion.end()) {
86+
static llvm::once_flag onceFlag;
87+
emitOpWarningOnce(
88+
onceFlag, callOp,
89+
llvm::formatv("uninlineable function @{0} has multiple call ops, we "
90+
"need to clone the function body for each call",
91+
calleeName)
92+
.str());
93+
rewriter.cloneRegionBefore(*movedRegionIt->second, namedCompRegion,
94+
namedCompRegion.begin());
95+
} else {
96+
FuncOp funcOp = symbolTable.lookup<FuncOp>(calleeName);
97+
assert(funcOp &&
98+
("Failed to lookup function: " + std::string(calleeName)).c_str());
99+
inlineRegionAndConvertTerminatorOp<ReturnOp>(funcOp.getBody(),
100+
namedCompRegion);
101+
calleeNameToMovedRegion[calleeName] = &namedCompRegion;
102+
}
103+
104+
rewriter.replaceOp(callOp, namedCompOp);
105+
}
106+
107+
class ImportUninlineableFuncCallsPass
108+
: public PassWrapper<ImportUninlineableFuncCallsPass,
109+
OperationPass<ModuleOp>> {
110+
public:
111+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ImportUninlineableFuncCallsPass)
112+
113+
void runOnOperation() final {
114+
ModuleOp moduleOp = getOperation();
115+
IRRewriter rewriter(moduleOp.getContext());
116+
SymbolTable symbolTable(moduleOp);
117+
// For every callee name, the first CallOp encountered with that symbol will
118+
// move the body of the callee into the created NamedComputationOp, and map
119+
// the symbol name to the moved region. Subsequent CallOps with that symbol
120+
// will clone the mapped region.
121+
llvm::SmallDenseMap<StringRef, Region*> calleeNameToMovedRegion;
122+
123+
moduleOp->walk([&](CallOp op) {
124+
if (isInlineableCallOp(op)) {
125+
return;
126+
}
127+
importCallOp(op, calleeNameToMovedRegion, rewriter, symbolTable);
128+
});
129+
130+
// Erase all func ops that now have no call ops.
131+
for (auto [calleeName, _] : calleeNameToMovedRegion) {
132+
symbolTable.erase(symbolTable.lookup(calleeName));
133+
}
134+
}
135+
136+
StringRef getArgument() const override {
137+
return "xla-sdy-import-uninlineable-func-calls";
138+
}
139+
140+
StringRef getDescription() const override {
141+
return "Creates a pass that converts a `CallOp` with a `backend_config` "
142+
"or `inlineable=false` frontend attr to a `NamedComputationOp` with "
143+
"the function body inlined and the name of the callee.";
144+
}
145+
};
146+
147+
} // namespace
148+
149+
std::unique_ptr<Pass> createImportUninlineableFuncCallsPass() {
150+
return std::make_unique<ImportUninlineableFuncCallsPass>();
151+
}
152+
153+
void registerImportUninlineableFuncCallsPass() {
154+
registerPass(createImportUninlineableFuncCallsPass);
155+
}
156+
157+
} // namespace sdy
158+
} // namespace mlir

shardy/round_trip_import/import_backend_func_calls.h renamed to shardy/round_trip_import/import_uninlineable_func_calls.h

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16-
#ifndef SHARDY_ROUND_TRIP_IMPORT_IMPORT_BACKEND_FUNC_CALLS_H_
17-
#define SHARDY_ROUND_TRIP_IMPORT_IMPORT_BACKEND_FUNC_CALLS_H_
16+
#ifndef SHARDY_ROUND_TRIP_IMPORT_IMPORT_UNINLINEABLE_FUNC_CALLS_H_
17+
#define SHARDY_ROUND_TRIP_IMPORT_IMPORT_UNINLINEABLE_FUNC_CALLS_H_
1818

1919
#include <memory>
2020

@@ -23,19 +23,21 @@ limitations under the License.
2323
namespace mlir {
2424
namespace sdy {
2525

26-
// Creates a pass that converts a `CallOp` with a `backend_config` attr to a
27-
// `NamedComputationOp` with the function body inlined and name of the callee.
26+
// Creates a pass that converts a `CallOp` with a `backend_config` or
27+
// `inlineable=false` frontend attr to a `NamedComputationOp` with the function
28+
// body inlined and name of the callee.
2829
//
29-
// This pass is used to handle host offloading calls which are non inlined
30-
// functions that require the callee to be propagated through.
30+
// This pass is used to handle host offloading and GPU stream calls which are
31+
// non inlined functions that require the callee to be propagated through.
3132
//
32-
// NOTE: it assumes that there is a unique callee for each caller.
33-
std::unique_ptr<Pass> createImportBackendFuncCallsPass();
33+
// NOTE: In case there are multiple call ops for the same callee, we will clone
34+
// the function body for each call op and emit a warning.
35+
std::unique_ptr<mlir::Pass> createImportUninlineableFuncCallsPass();
3436

35-
// Register the xla-sdy-import-backend-func-calls pass.
36-
void registerImportBackendFuncCallsPass();
37+
// Register the xla-sdy-import-uninlineable-calls pass.
38+
void registerImportUninlineableFuncCallsPass();
3739

3840
} // namespace sdy
3941
} // namespace mlir
4042

41-
#endif // SHARDY_ROUND_TRIP_IMPORT_IMPORT_BACKEND_FUNC_CALLS_H_
43+
#endif // SHARDY_ROUND_TRIP_IMPORT_IMPORT_UNINLINEABLE_FUNC_CALLS_H_

0 commit comments

Comments
 (0)