Skip to content

Commit 7a9fd20

Browse files
authored
chore: cpp side changes for tf SavedModel export (#1427)
1 parent 97c7119 commit 7a9fd20

File tree

4 files changed

+114
-3
lines changed

4 files changed

+114
-3
lines changed

.github/workflows/regenerate-mlir-bindings.yml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,17 @@ jobs:
1616
with:
1717
version: '1.10'
1818
- uses: julia-actions/cache@v2
19-
- uses: actions/checkout@v4
19+
- uses: bazel-contrib/[email protected]
20+
name: Set up Bazel
2021
with:
21-
ref: main
22+
# Avoid downloading Bazel every time.
23+
bazelisk-cache: true
24+
# Store build cache per workflow.
25+
disk-cache: ${{ github.workflow }}-${{ matrix.os }}-${{ matrix.version }}
26+
# Share repository cache between workflows.
27+
repository-cache: true
28+
bazelisk-version: 1.x
29+
- uses: actions/checkout@v4
2230
- name: Install JuliaFormatter.jl
2331
shell: julia --color=yes {0}
2432
run: |

deps/ReactantExtra/API.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,17 @@ extern "C" MlirOperation mlirOperationParse(MlirContext ctx, MlirBlock block,
270270
.release()};
271271
}
272272

273+
extern "C" MlirType mlirGetFunctionTypeFromOperation(MlirOperation op) {
274+
if (auto funcOp = dyn_cast<mlir::FunctionOpInterface>(unwrap(op))) {
275+
return wrap(funcOp.getFunctionType());
276+
}
277+
ReactantThrowError("Not a function op");
278+
}
279+
280+
extern "C" bool mlirIsFunctionOpInterface(MlirOperation op) {
281+
return llvm::isa<mlir::FunctionOpInterface>(unwrap(op));
282+
}
283+
273284
// TODO mlirComplexAttrGetnValue
274285
// TODO extern "C" MlirTypeID mlirComplexAttrGetTypeID(void) { return
275286
// wrap(complex::NumberAttr::getTypeID()); }

deps/ReactantExtra/BUILD

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,8 @@ cc_library(
888888
"-Wl,-exported_symbol,_hlo_sharding_*",
889889
"-Wl,-exported_symbol,_free_ifrt_sharding",
890890
"-Wl,-exported_symbol,_addSdyPropagationPipeline",
891+
"-Wl,-exported_symbol,_mlirGetFunctionTypeFromOperation",
892+
"-Wl,-exported_symbol,_mlirIsFunctionOpInterface",
891893
],
892894
}),
893895
linkstatic = True,
@@ -1434,6 +1436,9 @@ genrule(
14341436
"@llvm-project//mlir:AsyncPassIncGen_filegroup",
14351437
"@llvm-project//mlir:GPUPassIncGen_filegroup",
14361438
"@stablehlo//:stablehlo/integrations/c/StablehloAttributes.h",
1439+
"@stablehlo//:stablehlo/integrations/c/StablehloDialect.h",
1440+
"@stablehlo//:stablehlo/integrations/c/StablehloDialectApi.h",
1441+
"@stablehlo//:stablehlo/integrations/c/StablehloTypes.h",
14371442
"@shardy//shardy/integrations/c:attributes.h",
14381443
"//:Project.toml",
14391444
"//:Manifest.toml",
@@ -1442,7 +1447,7 @@ genrule(
14421447
"//:make.jl",
14431448
],
14441449
outs = ["libMLIR_h.jl"],
1445-
cmd = "$$JULIA \"--project=$(location //:Project.toml)\" \"$(location //:make.jl)\" \"$(location @llvm-project//mlir:include/mlir-c/Bindings/Python/Interop.h)\" \"$(location @llvm-project//llvm:include/llvm-c/Support.h)\" \"$(locations @llvm-project//mlir:ConversionPassIncGen_filegroup)\" \"$(location @stablehlo//:stablehlo/integrations/c/StablehloAttributes.h)\" \"$(location @shardy//shardy/integrations/c:attributes.h)\" \"$@\"",
1450+
cmd = "$$JULIA \"--color=yes\" \"--project=$(location //:Project.toml)\" \"$(location //:make.jl)\" \"$(location @llvm-project//mlir:include/mlir-c/Bindings/Python/Interop.h)\" \"$(location @llvm-project//llvm:include/llvm-c/Support.h)\" \"$(locations @llvm-project//mlir:ConversionPassIncGen_filegroup)\" \"$(location @stablehlo//:stablehlo/integrations/c/StablehloAttributes.h)\" \"$(location @shardy//shardy/integrations/c:attributes.h)\" \"$@\"",
14461451
tags = [
14471452
"jlrule",
14481453
],

src/mlir/libMLIR_h.jl

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10507,6 +10507,93 @@ function stablehloResultAccuracyAttrGetMode(attr)
1050710507
@ccall mlir_c.stablehloResultAccuracyAttrGetMode(attr::MlirAttribute)::MlirAttribute
1050810508
end
1050910509

10510+
function mlirGetDialectHandle__stablehlo__()
10511+
@ccall mlir_c.mlirGetDialectHandle__stablehlo__()::MlirDialectHandle
10512+
end
10513+
10514+
function stablehloGetApiVersion()
10515+
@ccall mlir_c.stablehloGetApiVersion()::Cint
10516+
end
10517+
10518+
@cenum MlirStablehloCompatibilityRequirement::UInt32 begin
10519+
NONE = 0x0000000000000000
10520+
WEEK_4 = 0x0000000000000001
10521+
WEEK_12 = 0x0000000000000002
10522+
MAX = 0x0000000000000003
10523+
end
10524+
10525+
function stablehloVersionFromCompatibilityRequirement(requirement, callback, userData)
10526+
@ccall mlir_c.stablehloVersionFromCompatibilityRequirement(
10527+
requirement::MlirStablehloCompatibilityRequirement,
10528+
callback::MlirStringCallback,
10529+
userData::Ptr{Cvoid},
10530+
)::Cvoid
10531+
end
10532+
10533+
function stablehloGetCurrentVersion(callback, userData)
10534+
@ccall mlir_c.stablehloGetCurrentVersion(
10535+
callback::MlirStringCallback, userData::Ptr{Cvoid}
10536+
)::Cvoid
10537+
end
10538+
10539+
function stablehloGetMinimumVersion(callback, userData)
10540+
@ccall mlir_c.stablehloGetMinimumVersion(
10541+
callback::MlirStringCallback, userData::Ptr{Cvoid}
10542+
)::Cvoid
10543+
end
10544+
10545+
function stablehloGetSmallerVersion(version1, version2, callback, userData)
10546+
@ccall mlir_c.stablehloGetSmallerVersion(
10547+
version1::MlirStringRef,
10548+
version2::MlirStringRef,
10549+
callback::MlirStringCallback,
10550+
userData::Ptr{Cvoid},
10551+
)::MlirLogicalResult
10552+
end
10553+
10554+
function stablehloSerializePortableArtifactFromStringRef(
10555+
moduleStr, targetVersion, callback, userData
10556+
)
10557+
@ccall mlir_c.stablehloSerializePortableArtifactFromStringRef(
10558+
moduleStr::MlirStringRef,
10559+
targetVersion::MlirStringRef,
10560+
callback::MlirStringCallback,
10561+
userData::Ptr{Cvoid},
10562+
)::MlirLogicalResult
10563+
end
10564+
10565+
function stablehloSerializePortableArtifactFromModule(
10566+
moduleStr, targetVersion, callback, userData, allowOtherDialects
10567+
)
10568+
@ccall mlir_c.stablehloSerializePortableArtifactFromModule(
10569+
moduleStr::MlirModule,
10570+
targetVersion::MlirStringRef,
10571+
callback::MlirStringCallback,
10572+
userData::Ptr{Cvoid},
10573+
allowOtherDialects::Bool,
10574+
)::MlirLogicalResult
10575+
end
10576+
10577+
function stablehloDeserializePortableArtifact(artifactStr, callback, userData)
10578+
@ccall mlir_c.stablehloDeserializePortableArtifact(
10579+
artifactStr::MlirStringRef, callback::MlirStringCallback, userData::Ptr{Cvoid}
10580+
)::MlirLogicalResult
10581+
end
10582+
10583+
function stablehloDeserializePortableArtifactNoError(artifactStr, ctx)
10584+
@ccall mlir_c.stablehloDeserializePortableArtifactNoError(
10585+
artifactStr::MlirStringRef, ctx::MlirContext
10586+
)::MlirModule
10587+
end
10588+
10589+
function stablehloTokenTypeGet(ctx)
10590+
@ccall mlir_c.stablehloTokenTypeGet(ctx::MlirContext)::MlirType
10591+
end
10592+
10593+
function stablehloTypeIsAToken(type)
10594+
@ccall mlir_c.stablehloTypeIsAToken(type::MlirType)::Bool
10595+
end
10596+
1051010597
function sdyAttributeIsAMeshAxisAttr(attr)
1051110598
@ccall mlir_c.sdyAttributeIsAMeshAxisAttr(attr::MlirAttribute)::Bool
1051210599
end

0 commit comments

Comments
 (0)