Skip to content

Commit 0a4e3c4

Browse files
Regenerate MLIR Bindings (#1456)
Co-authored-by: enzyme-ci-bot[bot] <78882869+enzyme-ci-bot[bot]@users.noreply.github.com>
1 parent d6b6ad0 commit 0a4e3c4

File tree

3 files changed

+68
-9
lines changed

3 files changed

+68
-9
lines changed

src/mlir/Dialects/Nvvm.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1593,19 +1593,22 @@ end
15931593
`elect_sync`
15941594
15951595
The `elect.sync` instruction elects one predicated active leader
1596-
thread from among a set of threads specified in membermask.
1597-
The membermask is set to `0xFFFFFFFF` for the current version
1598-
of this Op. The predicate result is set to `True` for the
1599-
leader thread, and `False` for all other threads.
1596+
thread from among a set of threads specified in the `membermask`.
1597+
When the `membermask` is not provided explicitly, a default value
1598+
of `0xFFFFFFFF` is used. The predicate result is set to `True` for
1599+
the leader thread, and `False` for all other threads.
16001600
16011601
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-elect-sync)
16021602
"""
1603-
function elect_sync(; pred::IR.Type, location=Location())
1603+
function elect_sync(
1604+
membermask=nothing::Union{Nothing,Value}; pred::IR.Type, location=Location()
1605+
)
16041606
op_ty_results = IR.Type[pred,]
16051607
operands = Value[]
16061608
owned_regions = Region[]
16071609
successors = Block[]
16081610
attributes = NamedAttribute[]
1611+
!isnothing(membermask) && push!(operands, membermask)
16091612

16101613
return create_operation(
16111614
"nvvm.elect.sync",

src/mlir/Dialects/TPU.jl

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -411,19 +411,18 @@ function enqueue_indirect_dma(
411411
source::Value,
412412
target::Value,
413413
offsets::Value,
414-
semaphore::Value;
414+
semaphore::Value,
415+
offset_filter=nothing::Union{Nothing,Value};
415416
add=nothing,
416-
offset_filter=nothing,
417417
location=Location(),
418418
)
419419
op_ty_results = IR.Type[]
420420
operands = Value[source, target, offsets, semaphore]
421421
owned_regions = Region[]
422422
successors = Block[]
423423
attributes = NamedAttribute[]
424+
!isnothing(offset_filter) && push!(operands, offset_filter)
424425
!isnothing(add) && push!(attributes, namedattribute("add", add))
425-
!isnothing(offset_filter) &&
426-
push!(attributes, namedattribute("offset_filter", offset_filter))
427426

428427
return create_operation(
429428
"tpu.enqueue_indirect_dma",
@@ -1579,6 +1578,25 @@ function wait_dma(semaphore::Value, ref::Value; location=Location())
15791578
)
15801579
end
15811580

1581+
function wait_indirect_dma(semaphore::Value, src::Value, dst::Value; location=Location())
1582+
op_ty_results = IR.Type[]
1583+
operands = Value[semaphore, src, dst]
1584+
owned_regions = Region[]
1585+
successors = Block[]
1586+
attributes = NamedAttribute[]
1587+
1588+
return create_operation(
1589+
"tpu.wait_indirect_dma",
1590+
location;
1591+
operands,
1592+
owned_regions,
1593+
successors,
1594+
attributes,
1595+
results=op_ty_results,
1596+
result_inference=false,
1597+
)
1598+
end
1599+
15821600
function weird(input::Value; output::IR.Type, location=Location())
15831601
op_ty_results = IR.Type[output,]
15841602
operands = Value[input,]

src/mlir/libMLIR_h.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2061,6 +2061,44 @@ function mlirBlockPrint(block, callback, userData)
20612061
)::Cvoid
20622062
end
20632063

2064+
"""
2065+
mlirBlockGetNumSuccessors(block)
2066+
2067+
Returns the number of successor blocks of the block.
2068+
"""
2069+
function mlirBlockGetNumSuccessors(block)
2070+
@ccall mlir_c.mlirBlockGetNumSuccessors(block::MlirBlock)::intptr_t
2071+
end
2072+
2073+
"""
2074+
mlirBlockGetSuccessor(block, pos)
2075+
2076+
Returns `pos`-th successor of the block.
2077+
"""
2078+
function mlirBlockGetSuccessor(block, pos)
2079+
@ccall mlir_c.mlirBlockGetSuccessor(block::MlirBlock, pos::intptr_t)::MlirBlock
2080+
end
2081+
2082+
"""
2083+
mlirBlockGetNumPredecessors(block)
2084+
2085+
Returns the number of predecessor blocks of the block.
2086+
"""
2087+
function mlirBlockGetNumPredecessors(block)
2088+
@ccall mlir_c.mlirBlockGetNumPredecessors(block::MlirBlock)::intptr_t
2089+
end
2090+
2091+
"""
2092+
mlirBlockGetPredecessor(block, pos)
2093+
2094+
Returns `pos`-th predecessor of the block.
2095+
2096+
WARNING: This getter is more expensive than the others here because the impl actually iterates the use-def chain (of block operands) anew for each indexed access.
2097+
"""
2098+
function mlirBlockGetPredecessor(block, pos)
2099+
@ccall mlir_c.mlirBlockGetPredecessor(block::MlirBlock, pos::intptr_t)::MlirBlock
2100+
end
2101+
20642102
"""
20652103
mlirValueIsNull(value)
20662104

0 commit comments

Comments
 (0)