Skip to content

Commit 2d74a0d

Browse files
Regenerate MLIR Bindings (#1519)
Co-authored-by: enzyme-ci-bot[bot] <78882869+enzyme-ci-bot[bot]@users.noreply.github.com>
1 parent 8f293d9 commit 2d74a0d

File tree

5 files changed

+194
-99
lines changed

5 files changed

+194
-99
lines changed

src/mlir/Dialects/Llvm.jl

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1542,6 +1542,82 @@ function icmp(
15421542
)
15431543
end
15441544

1545+
"""
1546+
`mlir_ifunc`
1547+
1548+
`llvm.mlir.ifunc` is a top level operation that defines a global ifunc.
1549+
It defines a new symbol and takes a symbol referring to a resolver function.
1550+
IFuncs can be called as regular functions. The function type is the same
1551+
as the IFuncType. The symbol is resolved at runtime by calling a resolver
1552+
function.
1553+
1554+
Examples:
1555+
1556+
```mlir
1557+
// IFuncs resolve a symbol at runtime using a resolver function.
1558+
llvm.mlir.ifunc external @foo: !llvm.func<f32 (i64)>, !llvm.ptr @resolve_foo
1559+
1560+
llvm.func @foo_1(i64) -> f32
1561+
llvm.func @foo_2(i64) -> f32
1562+
1563+
llvm.func @resolve_foo() -> !llvm.ptr attributes {
1564+
%0 = llvm.mlir.addressof @foo_2 : !llvm.ptr
1565+
%1 = llvm.mlir.addressof @foo_1 : !llvm.ptr
1566+
1567+
// ... Logic selecting from foo_{1, 2}
1568+
1569+
// Return function pointer to the selected function
1570+
llvm.return %7 : !llvm.ptr
1571+
}
1572+
1573+
llvm.func @use_foo() {
1574+
// IFuncs are called as regular functions
1575+
%res = llvm.call @foo(%value) : i64 -> f32
1576+
}
1577+
```
1578+
"""
1579+
function mlir_ifunc(;
1580+
sym_name,
1581+
i_func_type,
1582+
resolver,
1583+
resolver_type,
1584+
linkage,
1585+
dso_local=nothing,
1586+
address_space=nothing,
1587+
unnamed_addr=nothing,
1588+
visibility_=nothing,
1589+
location=Location(),
1590+
)
1591+
op_ty_results = IR.Type[]
1592+
operands = Value[]
1593+
owned_regions = Region[]
1594+
successors = Block[]
1595+
attributes = NamedAttribute[
1596+
namedattribute("sym_name", sym_name),
1597+
namedattribute("i_func_type", i_func_type),
1598+
namedattribute("resolver", resolver),
1599+
namedattribute("resolver_type", resolver_type),
1600+
namedattribute("linkage", linkage),
1601+
]
1602+
!isnothing(dso_local) && push!(attributes, namedattribute("dso_local", dso_local))
1603+
!isnothing(address_space) &&
1604+
push!(attributes, namedattribute("address_space", address_space))
1605+
!isnothing(unnamed_addr) &&
1606+
push!(attributes, namedattribute("unnamed_addr", unnamed_addr))
1607+
!isnothing(visibility_) && push!(attributes, namedattribute("visibility_", visibility_))
1608+
1609+
return create_operation(
1610+
"llvm.mlir.ifunc",
1611+
location;
1612+
operands,
1613+
owned_regions,
1614+
successors,
1615+
attributes,
1616+
results=op_ty_results,
1617+
result_inference=false,
1618+
)
1619+
end
1620+
15451621
"""
15461622
`indirectbr`
15471623

src/mlir/Dialects/MemRef.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,11 @@ A set `nontemporal` attribute indicates that this load is not expected to
220220
be reused in the cache. For details, refer to the
221221
[LLVM load instruction](https://llvm.org/docs/LangRef.html#load-instruction).
222222
223+
An optional `alignment` attribute specifies the byte alignment of the
224+
load operation. It must be a positive power of 2. The operation must access
225+
memory at an address aligned to this boundary. Violations may lead to
226+
architecture-specific faults or performance penalties.
227+
A value of 0 indicates no specific alignment requirement.
223228
# Example
224229
225230
```mlir
@@ -231,6 +236,7 @@ function load(
231236
indices::Vector{Value};
232237
result=nothing::Union{Nothing,IR.Type},
233238
nontemporal=nothing,
239+
alignment=nothing,
234240
location=Location(),
235241
)
236242
op_ty_results = IR.Type[]
@@ -240,6 +246,7 @@ function load(
240246
attributes = NamedAttribute[]
241247
!isnothing(result) && push!(op_ty_results, result)
242248
!isnothing(nontemporal) && push!(attributes, namedattribute("nontemporal", nontemporal))
249+
!isnothing(alignment) && push!(attributes, namedattribute("alignment", alignment))
243250

244251
return create_operation(
245252
"memref.load",
@@ -1522,6 +1529,11 @@ A set `nontemporal` attribute indicates that this store is not expected to
15221529
be reused in the cache. For details, refer to the
15231530
[LLVM store instruction](https://llvm.org/docs/LangRef.html#store-instruction).
15241531
1532+
An optional `alignment` attribute specifies the byte alignment of the
1533+
store operation. It must be a positive power of 2. The operation must access
1534+
memory at an address aligned to this boundary. Violations may lead to
1535+
architecture-specific faults or performance penalties.
1536+
A value of 0 indicates no specific alignment requirement.
15251537
# Example
15261538
15271539
```mlir
@@ -1533,6 +1545,7 @@ function store(
15331545
memref::Value,
15341546
indices::Vector{Value};
15351547
nontemporal=nothing,
1548+
alignment=nothing,
15361549
location=Location(),
15371550
)
15381551
op_ty_results = IR.Type[]
@@ -1541,6 +1554,7 @@ function store(
15411554
successors = Block[]
15421555
attributes = NamedAttribute[]
15431556
!isnothing(nontemporal) && push!(attributes, namedattribute("nontemporal", nontemporal))
1557+
!isnothing(alignment) && push!(attributes, namedattribute("alignment", alignment))
15441558

15451559
return create_operation(
15461560
"memref.store",

src/mlir/Dialects/MosaicGPU.jl

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,28 @@ function async_load(
101101
)
102102
end
103103

104+
function async_load_tmem(
105+
source::Value; result_0=nothing::Union{Nothing,IR.Type}, location=Location()
106+
)
107+
op_ty_results = IR.Type[]
108+
operands = Value[source,]
109+
owned_regions = Region[]
110+
successors = Block[]
111+
attributes = NamedAttribute[]
112+
!isnothing(result_0) && push!(op_ty_results, result_0)
113+
114+
return create_operation(
115+
"mosaic_gpu.async_load_tmem",
116+
location;
117+
operands,
118+
owned_regions,
119+
successors,
120+
attributes,
121+
results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
122+
result_inference=(length(op_ty_results) == 0 ? true : false),
123+
)
124+
end
125+
104126
"""
105127
`async_store`
106128
@@ -157,6 +179,25 @@ function async_store(
157179
)
158180
end
159181

182+
function async_store_tmem(source::Value, destination::Value; location=Location())
183+
op_ty_results = IR.Type[]
184+
operands = Value[source, destination]
185+
owned_regions = Region[]
186+
successors = Block[]
187+
attributes = NamedAttribute[]
188+
189+
return create_operation(
190+
"mosaic_gpu.async_store_tmem",
191+
location;
192+
operands,
193+
owned_regions,
194+
successors,
195+
attributes,
196+
results=op_ty_results,
197+
result_inference=false,
198+
)
199+
end
200+
160201
"""
161202
`broadcast_in_dim`
162203
@@ -511,6 +552,36 @@ function tmem_dealloc(tmem_ref::Value; location=Location())
511552
)
512553
end
513554

555+
"""
556+
`tmem_relinquish_alloc_permit`
557+
558+
The instruction specifies that the CTA of the executing thread is
559+
relinquishing the right to allocate Tensor Memory. So, it is illegal for a
560+
CTA to perform `tmem_alloc` after any of its constituent threads execute
561+
`tmem_relinquish_alloc_permit`.
562+
563+
If `collective` is `true`, applies to collective TMEM allocations.
564+
"""
565+
function tmem_relinquish_alloc_permit(; collective=nothing, location=Location())
566+
op_ty_results = IR.Type[]
567+
operands = Value[]
568+
owned_regions = Region[]
569+
successors = Block[]
570+
attributes = NamedAttribute[]
571+
!isnothing(collective) && push!(attributes, namedattribute("collective", collective))
572+
573+
return create_operation(
574+
"mosaic_gpu.tmem_relinquish_alloc_permit",
575+
location;
576+
operands,
577+
owned_regions,
578+
successors,
579+
attributes,
580+
results=op_ty_results,
581+
result_inference=false,
582+
)
583+
end
584+
514585
"""
515586
`wgmma`
516587

src/mlir/Dialects/Triton.jl

Lines changed: 24 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -127,31 +127,6 @@ function func(;
127127
)
128128
end
129129

130-
"""
131-
`reinterpret_tensor_descriptor`
132-
133-
This Op exists to help the transition from untyped raw TMA objects to typed Tensor descriptor objects.
134-
Ideally, we can remove this once the APIs are fully fleshed out.
135-
"""
136-
function reinterpret_tensor_descriptor(rawDesc::Value; result::IR.Type, location=Location())
137-
op_ty_results = IR.Type[result,]
138-
operands = Value[rawDesc,]
139-
owned_regions = Region[]
140-
successors = Block[]
141-
attributes = NamedAttribute[]
142-
143-
return create_operation(
144-
"tt.reinterpret_tensor_descriptor",
145-
location;
146-
operands,
147-
owned_regions,
148-
successors,
149-
attributes,
150-
results=op_ty_results,
151-
result_inference=false,
152-
)
153-
end
154-
155130
"""
156131
`return_`
157132
@@ -736,79 +711,6 @@ function expand_dims(
736711
)
737712
end
738713

739-
function experimental_tensormap_create(
740-
desc_ptr::Value,
741-
global_address::Value,
742-
box_dim::Vector{Value},
743-
global_dim::Vector{Value},
744-
global_stride::Vector{Value},
745-
element_stride::Vector{Value};
746-
elem_type,
747-
interleave_layout,
748-
swizzle_mode,
749-
fill_mode,
750-
location=Location(),
751-
)
752-
op_ty_results = IR.Type[]
753-
operands = Value[
754-
desc_ptr,
755-
global_address,
756-
box_dim...,
757-
global_dim...,
758-
global_stride...,
759-
element_stride...,
760-
]
761-
owned_regions = Region[]
762-
successors = Block[]
763-
attributes = NamedAttribute[
764-
namedattribute("elem_type", elem_type),
765-
namedattribute("interleave_layout", interleave_layout),
766-
namedattribute("swizzle_mode", swizzle_mode),
767-
namedattribute("fill_mode", fill_mode),
768-
]
769-
push!(
770-
attributes,
771-
operandsegmentsizes([
772-
1,
773-
1,
774-
length(box_dim),
775-
length(global_dim),
776-
length(global_stride),
777-
length(element_stride),
778-
]),
779-
)
780-
781-
return create_operation(
782-
"tt.experimental_tensormap_create",
783-
location;
784-
operands,
785-
owned_regions,
786-
successors,
787-
attributes,
788-
results=op_ty_results,
789-
result_inference=false,
790-
)
791-
end
792-
793-
function experimental_tensormap_fenceproxy_acquire(desc_ptr::Value; location=Location())
794-
op_ty_results = IR.Type[]
795-
operands = Value[desc_ptr,]
796-
owned_regions = Region[]
797-
successors = Block[]
798-
attributes = NamedAttribute[]
799-
800-
return create_operation(
801-
"tt.experimental_tensormap_fenceproxy_acquire",
802-
location;
803-
operands,
804-
owned_regions,
805-
successors,
806-
attributes,
807-
results=op_ty_results,
808-
result_inference=false,
809-
)
810-
end
811-
812714
"""
813715
`extern_elementwise`
814716
@@ -966,12 +868,15 @@ Return the histogram of the input tensor. The number of bins is equal to
966868
the dimension of the output tensor. Each bin has a width of 1 and bins
967869
start at 0.
968870
"""
969-
function histogram(src::Value; result::IR.Type, location=Location())
871+
function histogram(
872+
src::Value, mask=nothing::Union{Nothing,Value}; result::IR.Type, location=Location()
873+
)
970874
op_ty_results = IR.Type[result,]
971875
operands = Value[src,]
972876
owned_regions = Region[]
973877
successors = Block[]
974878
attributes = NamedAttribute[]
879+
!isnothing(mask) && push!(operands, mask)
975880

976881
return create_operation(
977882
"tt.histogram",
@@ -1566,4 +1471,24 @@ function trans(
15661471
)
15671472
end
15681473

1474+
function unsplat(src::Value; result=nothing::Union{Nothing,IR.Type}, location=Location())
1475+
op_ty_results = IR.Type[]
1476+
operands = Value[src,]
1477+
owned_regions = Region[]
1478+
successors = Block[]
1479+
attributes = NamedAttribute[]
1480+
!isnothing(result) && push!(op_ty_results, result)
1481+
1482+
return create_operation(
1483+
"tt.unsplat",
1484+
location;
1485+
operands,
1486+
owned_regions,
1487+
successors,
1488+
attributes,
1489+
results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
1490+
result_inference=(length(op_ty_results) == 0 ? true : false),
1491+
)
1492+
end
1493+
15691494
end # tt

0 commit comments

Comments
 (0)