Skip to content

Commit f8c4137

Browse files
authored
StableHLO Builder APIs (#2846)
# Declarative MLIR Builder APIs Goal: Provide a builder that abstracts away the notion of location and insertion point for use cases that construct full graphs from C++. See `MlirBuilderTest.cpp` for examples. ## Usage The builders look fairly similar to XlaBuilder's declarative style, see `MlirBuilderTest.cpp` for a few example programs: ```c++ StablehloModuleBuilder mb; { // Build Main Func ScopedBuilderLocation loc(mb.get(), FileLineColLoc(mb.get(), "main.mlir")); func::FunctionBuilder fb(mb.get(), mb->getLoc(), "main"); auto type4xi64 = RankedTensorType::get({4}, fb.getOpBuilder().getI64Type()); auto arg0 = func::Argument(fb, type4xi64); auto cst = stablehlo::Constant(fb, 1); auto add = chlo::BroadcastAdd(arg0, cst); auto topkAndIndices = chlo::TopK(add, 2); auto broadcast = stablehlo::BroadcastInDim(topkAndIndices[0].getType(), cst, {}); auto equal = tosa::Equal(topkAndIndices[0], broadcast); func::Return(fb, {equal}); } mb->build()->dump(); // module { // func.func @main(%arg0: tensor<4xi64>) -> tensor<2xi1> { // %c = stablehlo.constant dense<1> : tensor<i64> // %0 = chlo.broadcast_add %arg0, %c : (tensor<4xi64>, tensor<i64>) -> tensor<4xi64> // %values, %indices = chlo.top_k(%0, k = 2) : tensor<4xi64> -> (tensor<2xi64>, tensor<2xi32>) // %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i64>) -> tensor<2xi64> // %2 = tosa.equal %values, %1 : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1> // return %2 : tensor<2xi1> // } // } ```
1 parent 937150f commit f8c4137

25 files changed

+6397
-1
lines changed

BUILD.bazel

Lines changed: 207 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "gentbl_filegroup", "td_library")
15-
load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library")
15+
load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library", "cc_test")
1616

1717
package(
1818
default_visibility = ["//visibility:public"],
@@ -1781,3 +1781,209 @@ test_suite(
17811781
"//stablehlo/tests:stablehlo_tests",
17821782
],
17831783
)
1784+
1785+
### Builder APIs
1786+
1787+
#####
1788+
# TableGen
1789+
#####
1790+
1791+
cc_binary(
1792+
name = "mlir_builder_tblgen",
1793+
srcs = ["stablehlo/integrations/cpp/builder/MlirBuilderTblgen.cpp"],
1794+
deps = [
1795+
"@llvm-project//llvm:Support",
1796+
"@llvm-project//llvm:TableGen",
1797+
"@llvm-project//mlir:Support",
1798+
"@llvm-project//mlir:TableGen",
1799+
],
1800+
)
1801+
1802+
####
1803+
## MlirBuilder base class
1804+
####
1805+
1806+
cc_library(
1807+
name = "mlir_builder",
1808+
srcs = ["stablehlo/integrations/cpp/builder/MlirBuilder.cpp"],
1809+
hdrs = ["stablehlo/integrations/cpp/builder/MlirBuilder.h"],
1810+
strip_include_prefix = ".",
1811+
deps = [
1812+
":attr_type_builder_util",
1813+
"@llvm-project//llvm:Support",
1814+
"@llvm-project//mlir:IR",
1815+
"@llvm-project//mlir:Support",
1816+
],
1817+
)
1818+
1819+
cc_test(
1820+
name = "mlir_builder_test",
1821+
srcs = ["stablehlo/integrations/cpp/builder/MlirBuilderTest.cpp"],
1822+
deps = [
1823+
":attr_type_builder_util",
1824+
":chlo_builder",
1825+
":func_builder",
1826+
":mlir_builder",
1827+
":register",
1828+
":stablehlo_builder",
1829+
":stablehlo_ops",
1830+
"@llvm-project//llvm:Support",
1831+
"@llvm-project//mlir:FuncDialect",
1832+
"@llvm-project//mlir:IR",
1833+
"@llvm-project//mlir:Support",
1834+
"@llvm-project//mlir:TosaDialect",
1835+
"@llvm-project//third-party/unittest:gmock",
1836+
"@llvm-project//third-party/unittest:gtest",
1837+
],
1838+
)
1839+
1840+
####
1841+
## Attr / Type Builder Helpers
1842+
####
1843+
1844+
cc_library(
1845+
name = "attr_type_builder_util",
1846+
srcs = ["stablehlo/integrations/cpp/builder/AttrTypeBuilderUtil.cpp"],
1847+
hdrs = ["stablehlo/integrations/cpp/builder/AttrTypeBuilderUtil.h"],
1848+
strip_include_prefix = ".",
1849+
deps = [
1850+
"@llvm-project//llvm:Support",
1851+
"@llvm-project//mlir:IR",
1852+
"@llvm-project//mlir:Support",
1853+
],
1854+
)
1855+
1856+
cc_test(
1857+
name = "attr_type_builder_util_test",
1858+
srcs = ["stablehlo/integrations/cpp/builder/AttrTypeBuilderUtilTest.cpp"],
1859+
deps = [
1860+
":attr_type_builder_util",
1861+
"@llvm-project//llvm:Support",
1862+
"@llvm-project//mlir:IR",
1863+
"@llvm-project//mlir:Support",
1864+
"@llvm-project//third-party/unittest:gmock",
1865+
"@llvm-project//third-party/unittest:gtest",
1866+
],
1867+
)
1868+
1869+
#####
1870+
## Dialect-specific builders
1871+
####
1872+
1873+
gentbl_cc_library(
1874+
name = "chlo_builder_inc",
1875+
tbl_outs = {
1876+
"stablehlo/integrations/cpp/builder/ChloBuilder.cpp.inc": ["-gen-builder-defs"],
1877+
"stablehlo/integrations/cpp/builder/ChloBuilder.h.inc": ["-gen-builder-decls"],
1878+
"stablehlo/integrations/cpp/builder/ChloBuilder.md": ["-gen-builder-docs"],
1879+
},
1880+
tblgen = ":mlir_builder_tblgen",
1881+
td_file = ":stablehlo/dialect/ChloOps.td",
1882+
deps = [
1883+
":chlo_ops_td_files",
1884+
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
1885+
"@llvm-project//mlir:OpBaseTdFiles",
1886+
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
1887+
],
1888+
)
1889+
1890+
cc_library(
1891+
name = "chlo_builder",
1892+
srcs = ["stablehlo/integrations/cpp/builder/ChloBuilder.cpp"],
1893+
hdrs = ["stablehlo/integrations/cpp/builder/ChloBuilder.h"],
1894+
strip_include_prefix = ".",
1895+
deps = [
1896+
":chlo_builder_inc",
1897+
":chlo_ops",
1898+
":mlir_builder",
1899+
"@llvm-project//llvm:Support",
1900+
"@llvm-project//mlir:IR",
1901+
"@llvm-project//mlir:Support",
1902+
],
1903+
)
1904+
1905+
gentbl_cc_library(
1906+
name = "func_builder_inc",
1907+
tbl_outs = {
1908+
"stablehlo/integrations/cpp/builder/FuncBuilder.cpp.inc": ["-gen-builder-defs"],
1909+
"stablehlo/integrations/cpp/builder/FuncBuilder.h.inc": ["-gen-builder-decls"],
1910+
"stablehlo/integrations/cpp/builder/FuncBuilder.md": ["-gen-builder-docs"],
1911+
},
1912+
tblgen = ":mlir_builder_tblgen",
1913+
td_file = "@llvm-project//mlir:include/mlir/Dialect/Func/IR/FuncOps.td",
1914+
deps = [
1915+
"@llvm-project//mlir:FuncTdFiles",
1916+
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
1917+
"@llvm-project//mlir:OpBaseTdFiles",
1918+
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
1919+
],
1920+
)
1921+
1922+
cc_library(
1923+
name = "func_builder",
1924+
srcs = ["stablehlo/integrations/cpp/builder/FuncBuilder.cpp"],
1925+
hdrs = ["stablehlo/integrations/cpp/builder/FuncBuilder.h"],
1926+
strip_include_prefix = ".",
1927+
deps = [
1928+
":func_builder_inc",
1929+
":mlir_builder",
1930+
"@llvm-project//mlir:FuncDialect",
1931+
"@llvm-project//mlir:IR",
1932+
"@llvm-project//mlir:Support",
1933+
],
1934+
)
1935+
1936+
gentbl_cc_library(
1937+
name = "stablehlo_builder_inc",
1938+
tbl_outs = {
1939+
"stablehlo/integrations/cpp/builder/StablehloBuilder.cpp.inc": ["-gen-builder-defs"],
1940+
"stablehlo/integrations/cpp/builder/StablehloBuilder.h.inc": ["-gen-builder-decls"],
1941+
"stablehlo/integrations/cpp/builder/StablehloBuilder.md": ["-gen-builder-docs"],
1942+
},
1943+
tblgen = ":mlir_builder_tblgen",
1944+
td_file = ":stablehlo/dialect/StablehloOps.td",
1945+
deps = [
1946+
":base_td_files",
1947+
":stablehlo_ops_td_filegroup",
1948+
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
1949+
"@llvm-project//mlir:OpBaseTdFiles",
1950+
"@llvm-project//mlir:ShapeOpsTdFiles",
1951+
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
1952+
],
1953+
)
1954+
1955+
cc_library(
1956+
name = "stablehlo_builder",
1957+
srcs = ["stablehlo/integrations/cpp/builder/StablehloBuilder.cpp"],
1958+
hdrs = ["stablehlo/integrations/cpp/builder/StablehloBuilder.h"],
1959+
strip_include_prefix = ".",
1960+
deps = [
1961+
":attr_type_builder_util",
1962+
":mlir_builder",
1963+
":stablehlo_builder_inc",
1964+
":stablehlo_ops",
1965+
":stablehlo_type_inference",
1966+
"@llvm-project//llvm:Support",
1967+
"@llvm-project//mlir:FuncDialect",
1968+
"@llvm-project//mlir:IR",
1969+
"@llvm-project//mlir:InferTypeOpInterface",
1970+
"@llvm-project//mlir:Support",
1971+
],
1972+
)
1973+
1974+
cc_test(
1975+
name = "stablehlo_builder_test",
1976+
srcs = ["stablehlo/integrations/cpp/builder/StablehloBuilderTest.cpp"],
1977+
deps = [
1978+
":attr_type_builder_util",
1979+
":func_builder",
1980+
":mlir_builder",
1981+
":register",
1982+
":stablehlo_builder",
1983+
":stablehlo_ops",
1984+
"@llvm-project//mlir:IR",
1985+
"@llvm-project//mlir:Support",
1986+
"@llvm-project//third-party/unittest:gmock",
1987+
"@llvm-project//third-party/unittest:gtest",
1988+
],
1989+
)

build_tools/github_actions/ci_build_docs.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ targets[":interpreter_pass_inc_gen_filegroup"]="bazel-bin/stablehlo/reference/in
4646
targets[":linalg_pass_inc_gen_filegroup"]="bazel-bin/stablehlo/conversions/linalg/transforms/stablehlo_linalg_passes.md"
4747
targets[":tosa_pass_inc_gen_filegroup"]="bazel-bin/stablehlo/conversions/tosa/transforms/stablehlo_tosa_passes.md"
4848
targets[":chlo_ops_inc_gen_filegroup"]="bazel-bin/stablehlo/dialect/chlo.md"
49+
targets[":stablehlo_builder_inc_filegroup"]="bazel-bin/stablehlo/integrations/cpp/builder/StablehloBuilder.md"
50+
targets[":chlo_builder_inc_filegroup"]="bazel-bin/stablehlo/integrations/cpp/builder/ChloBuilder.md"
51+
targets[":func_builder_inc_filegroup"]="bazel-bin/stablehlo/integrations/cpp/builder/FuncBuilder.md"
4952

5053
bazel build "${!targets[@]}"
5154

docs/_toc.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,14 @@ toc:
5757
path: /stablehlo/generated/chlo.md
5858
- title: The VHLO dialect
5959
path: /stablehlo/vhlo
60+
- title: C++ Builder APIs
61+
section:
62+
- title: StableHLO Builders
63+
path: /stablehlo/generated/StablehloBuilder.md
64+
- title: CHLO Builders
65+
path: /stablehlo/generated/ChloBuilder.md
66+
- title: Func Builders
67+
path: /stablehlo/generated/FuncBuilder.md
6068
- title: IDE setup tips
6169
path: /stablehlo/ide
6270
- title: StableHLO Passes

0 commit comments

Comments
 (0)