Skip to content

Commit 6a760a6

Browse files
Remove type piracy and enable function wrapping (#505)
Ignore MPS functions for now since MPS does not seem to have a dylib.
1 parent b949b14 commit 6a760a6

File tree

4 files changed

+119
-20
lines changed

4 files changed

+119
-20
lines changed

lib/mtl/libmtl.jl

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,33 +19,83 @@ struct MTLTextureSwizzleChannels
1919
alpha::MTLTextureSwizzle
2020
end
2121

22+
function MTLTextureSwizzleChannelsMake(r, g, b, a)
23+
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLTextureSwizzleChannelsMake(r::MTLTextureSwizzle,
24+
g::MTLTextureSwizzle,
25+
b::MTLTextureSwizzle,
26+
a::MTLTextureSwizzle)::MTLTextureSwizzleChannels
27+
end
28+
2229
struct MTLOrigin
2330
x::NSUInteger
2431
y::NSUInteger
2532
z::NSUInteger
2633
MTLOrigin(x=0, y=0, z=0) = new(x, y, z)
2734
end
2835

36+
function MTLOriginMake(x, y, z)
37+
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLOriginMake(x::NSUInteger,
38+
y::NSUInteger,
39+
z::NSUInteger)::MTLOrigin
40+
end
41+
2942
struct MTLSize
3043
width::NSUInteger
3144
height::NSUInteger
3245
depth::NSUInteger
3346
MTLSize(w=1, h=1, d=1) = new(w, h, d)
3447
end
3548

49+
function MTLSizeMake(width, height, depth)
50+
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLSizeMake(width::NSUInteger,
51+
height::NSUInteger,
52+
depth::NSUInteger)::MTLSize
53+
end
54+
3655
struct MTLRegion
3756
origin::MTLOrigin
3857
size::MTLSize
3958
MTLRegion(origin=MTLOrigin(), size=MTLSize()) = new(origin, size)
4059
end
4160

61+
function MTLRegionMake1D(x, width)
62+
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLRegionMake1D(x::NSUInteger,
63+
width::NSUInteger)::MTLRegion
64+
end
65+
66+
function MTLRegionMake2D(x, y, width, height)
67+
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLRegionMake2D(x::NSUInteger,
68+
y::NSUInteger,
69+
width::NSUInteger,
70+
height::NSUInteger)::MTLRegion
71+
end
72+
73+
function MTLRegionMake3D(x, y, z, width, height, depth)
74+
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLRegionMake3D(x::NSUInteger,
75+
y::NSUInteger,
76+
z::NSUInteger,
77+
width::NSUInteger,
78+
height::NSUInteger,
79+
depth::NSUInteger)::MTLRegion
80+
end
81+
4282
struct MTLSamplePosition
4383
x::Cfloat
4484
y::Cfloat
4585
end
4686

87+
function MTLSamplePositionMake(x, y)
88+
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLSamplePositionMake(x::Cfloat,
89+
y::Cfloat)::MTLSamplePosition
90+
end
91+
4792
const MTLCoordinate2D = MTLSamplePosition
4893

94+
function MTLCoordinate2DMake(x, y)
95+
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLCoordinate2DMake(x::Cfloat,
96+
y::Cfloat)::MTLCoordinate2D
97+
end
98+
4999
struct MTLResourceID
50100
_impl::UInt64
51101
end
@@ -654,6 +704,13 @@ struct MTLClearColor
654704
alpha::Cdouble
655705
end
656706

707+
function MTLClearColorMake(red, green, blue, alpha)
708+
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLClearColorMake(red::Cdouble,
709+
green::Cdouble,
710+
blue::Cdouble,
711+
alpha::Cdouble)::MTLClearColor
712+
end
713+
657714
@cenum MTLLoadAction::UInt64 begin
658715
MTLLoadActionDontCare = 0x0000000000000000
659716
MTLLoadActionLoad = 0x0000000000000001
@@ -1137,13 +1194,26 @@ end
11371194

11381195
const MTLPackedFloat3 = _MTLPackedFloat3
11391196

1197+
function MTLPackedFloat3Make(x, y, z)
1198+
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLPackedFloat3Make(x::Cfloat,
1199+
y::Cfloat,
1200+
z::Cfloat)::MTLPackedFloat3
1201+
end
1202+
11401203
struct MTLPackedFloatQuaternion
11411204
x::Cfloat
11421205
y::Cfloat
11431206
z::Cfloat
11441207
w::Cfloat
11451208
end
11461209

1210+
function MTLPackedFloatQuaternionMake(x, y, z, w)
1211+
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLPackedFloatQuaternionMake(x::Cfloat,
1212+
y::Cfloat,
1213+
z::Cfloat,
1214+
w::Cfloat)::MTLPackedFloatQuaternion
1215+
end
1216+
11471217
struct _MTLPackedFloat4x3
11481218
columns::NTuple{4,MTLPackedFloat3}
11491219
end
@@ -1308,6 +1378,11 @@ struct MTLIndirectCommandBufferExecutionRange
13081378
length::UInt32
13091379
end
13101380

1381+
function MTLIndirectCommandBufferExecutionRangeMake(location, length)
1382+
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLIndirectCommandBufferExecutionRangeMake(location::UInt32,
1383+
length::UInt32)::MTLIndirectCommandBufferExecutionRange
1384+
end
1385+
13111386
@cenum MTLFunctionLogType::UInt64 begin
13121387
MTLFunctionLogTypeValidation = 0x0000000000000000
13131388
end
@@ -1395,3 +1470,23 @@ end
13951470
end
13961471

13971472
const MTLIOCompressionContext = Ptr{Cvoid}
1473+
1474+
function MTLIOCompressionContextDefaultChunkSize()
1475+
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLIOCompressionContextDefaultChunkSize()::Csize_t
1476+
end
1477+
1478+
function MTLIOCreateCompressionContext(path, type, chunkSize)
1479+
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLIOCreateCompressionContext(path::Cstring,
1480+
type::MTLIOCompressionMethod,
1481+
chunkSize::Csize_t)::MTLIOCompressionContext
1482+
end
1483+
1484+
function MTLIOCompressionContextAppendData(context, data, size)
1485+
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLIOCompressionContextAppendData(context::MTLIOCompressionContext,
1486+
data::Ptr{Cvoid},
1487+
size::Csize_t)::Cvoid
1488+
end
1489+
1490+
function MTLIOFlushAndDestroyCompressionContext(context)
1491+
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLIOFlushAndDestroyCompressionContext(context::MTLIOCompressionContext)::MTLIOCompressionStatus
1492+
end

res/wrap/libmps.toml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,16 @@ printer_blacklist = [
1010
"CF.*",
1111
"MTL.*",
1212
"NS.*",
13-
"BOOL"
13+
"BOOL",
14+
# Not sure how to access the MPS functions so don't wrap for now
15+
"MPSDataTypeBitsCount",
16+
"MPSSizeofMPSDataType",
17+
"MPSSizeofMPSDataType",
18+
"MPSFindIntegerDivisionParams",
19+
"MPSGetCustomKernelMaxBatchSize",
20+
"MPSGetCustomKernelBatchedDestinationIndex",
21+
"MPSGetCustomKernelBatchedSourceIndex",
22+
"MPSGetCustomKernelBroadcastSourceIndex",
1423
]
1524

1625
[codegen]

res/wrap/libmtl.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[general]
2-
library_name = "libmtl"
2+
library_name = "Symbol(\"/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib\")"
33
output_file_path = "../../lib/mtl/libmtl.jl"
44

55
generate_isystem_symbols = false

res/wrap/wrap.jl

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ using Clang_jll
33
Clang_jll.libclang = "/Applications/Xcode.app/Contents/Frameworks/libclang.dylib"
44

55
using Clang.Generators
6-
using Clang.Generators: LinkEnumAlias
76
using Clang
87
using Glob
98
using JLD2
@@ -13,11 +12,8 @@ using Logging
1312
# Use system SDK
1413
SDK_PATH = `xcrun --show-sdk-path` |> open |> readchomp |> String
1514

16-
# Hack to prevent printing of functions for now
17-
Generators.skip_check(dag::Generators.ExprDAG, node::Generators.ExprNode{Generators.FunctionProto}) = true
18-
1915
main(name::AbstractString; kwargs...) = main([name]; kwargs...)
20-
function main(names=["all"]; sdk_path=SDK_PATH)
16+
function main(names::AbstractVector=["all"]; sdk_path=SDK_PATH)
2117
path_to_framework(framework) = joinpath(sdk_path, "System/Library/Frameworks/",framework*".framework","Headers")
2218
path_to_mps_framework(framework) = joinpath(sdk_path, "System/Library/Frameworks/","MetalPerformanceShaders.framework","Frameworks",framework*".framework","Headers")
2319

@@ -43,16 +39,6 @@ function main(names=["all"]; sdk_path=SDK_PATH)
4339
push!(ctxs, tctx)
4440
end
4541

46-
# if "all" in names || "libfoundation" in names || "foundation" in names
47-
# fwpath = path_to_framework("Foundation")
48-
# tctx = wrap("libfoundation", joinpath(foundation, "Foundation.h");, defines=["__builtin_va_list"])
49-
# push!(ctxs, tctx)
50-
# end
51-
# if "all" in names || "libcf" in names || "cf" in names
52-
# fwpath = path_to_framework("CoreFoundation")
53-
# tctx = wrap("libfoundation", joinpath(fwpath, "CoreFoundation.h");, defines=["__builtin_va_list"])
54-
# push!(ctxs, tctx)
55-
# end
5642
return ctxs
5743
end
5844

@@ -119,16 +105,22 @@ function create_objc_context(headers::Vector, args::Vector=String[], options::Di
119105
"/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain"
120106
]
121107

108+
regen = if haskey(options, "general") && haskey(options["general"], "regenerate_dependent_headers")
109+
options["general"]["regenerate_dependent_headers"]
110+
else
111+
false
112+
end
113+
122114
# Since the framework we're wrapping is a system header,
123115
# find all dependent headers, then remove all but the relevant ones
124116
# also temporarily disable logging
125117
dep_headers_fname = if haskey(options, "general") && haskey(options["general"], "library_name")
126-
options["general"]["library_name"]*".JLD2"
118+
splitext(splitpath(options["general"]["output_file_path"])[end])[1]*".JLD2"
127119
else
128120
nothing
129121
end
130122
Base.CoreLogging._min_enabled_level[] = Logging.Info+1
131-
dependent_headers = if !isnothing(dep_headers_fname) && isfile(dep_headers_fname)
123+
dependent_headers = if !regen && !isnothing(dep_headers_fname) && isfile(dep_headers_fname)
132124
JLD2.load(dep_headers_fname, "dep_headers")
133125
else
134126
all_headers = find_dependent_headers(headers,args,[])
@@ -137,7 +129,10 @@ function create_objc_context(headers::Vector, args::Vector=String[], options::Di
137129
target_framework = "/"*joinpath(Sys.splitpath(header)[end-2:end-1])
138130
dep_headers = append!(dep_headers, filter(s -> occursin(target_framework, s), all_headers))
139131
end
140-
JLD2.@save dep_headers_fname dep_headers
132+
if haskey(options, "general") && haskey(options["general"], "extra_target_headers")
133+
append!(dep_headers, options["general"]["extra_target_headers"])
134+
end
135+
regen || JLD2.@save dep_headers_fname dep_headers
141136
dep_headers
142137
end
143138
Base.CoreLogging._min_enabled_level[] = Logging.Debug

0 commit comments

Comments
 (0)