@@ -151,11 +151,22 @@ func getPointerMutability(text: String) -> Mutability? {
151151 case " UnsafeMutablePointer " : return . Mutable
152152 case " UnsafeRawPointer " : return . Immutable
153153 case " UnsafeMutableRawPointer " : return . Mutable
154+ case " OpaquePointer " : return . Immutable
154155 default :
155156 return nil
156157 }
157158}
158159
160+ func isRawPointerType( text: String ) -> Bool {
161+ switch text {
162+ case " UnsafeRawPointer " : return true
163+ case " UnsafeMutableRawPointer " : return true
164+ case " OpaquePointer " : return true
165+ default :
166+ return false
167+ }
168+ }
169+
159170func getSafePointerName( mut: Mutability , generateSpan: Bool , isRaw: Bool ) -> TokenSyntax {
160171 switch ( mut, generateSpan, isRaw) {
161172 case ( . Immutable, true , true ) : return " RawSpan "
@@ -180,9 +191,13 @@ func transformType(_ prev: TypeSyntax, _ variant: Variant, _ isSizedBy: Bool) th
180191 }
181192 let name = try getTypeName ( prev)
182193 let text = name. text
183- if !isSizedBy && ( text == " UnsafeRawPointer " || text == " UnsafeMutableRawPointer " ) {
194+ let isRaw = isRawPointerType ( text: text)
195+ if isRaw && !isSizedBy {
184196 throw DiagnosticError ( " raw pointers only supported for SizedBy " , node: name)
185197 }
198+ if !isRaw && isSizedBy {
199+ throw DiagnosticError ( " SizedBy only supported for raw pointers " , node: name)
200+ }
186201
187202 guard let kind: Mutability = getPointerMutability ( text: text) else {
188203 throw DiagnosticError (
@@ -390,7 +405,7 @@ struct CountedOrSizedPointerThunkBuilder: PointerBoundsThunkBuilder {
390405 var args = argOverrides
391406 let argExpr = ExprSyntax ( " \( unwrappedName) .baseAddress " )
392407 assert ( args [ index] == nil )
393- args [ index] = unwrapIfNonnullable ( argExpr)
408+ args [ index] = try castPointerToOpaquePointer ( unwrapIfNonnullable ( argExpr) )
394409 let call = try base. buildFunctionCall ( args, variant)
395410 let ptrRef = unwrapIfNullable ( ExprSyntax ( DeclReferenceExprSyntax ( baseName: name) ) )
396411
@@ -412,7 +427,26 @@ struct CountedOrSizedPointerThunkBuilder: PointerBoundsThunkBuilder {
412427 return ExprSyntax ( " \( name) . \( raw: countName) " )
413428 }
414429
415- func getPointerArg( ) -> ExprSyntax {
430+ func peelOptionalType( _ type: TypeSyntax ) -> TypeSyntax {
431+ if let optType = type. as ( OptionalTypeSyntax . self) {
432+ return optType. wrappedType
433+ }
434+ if let impOptType = type. as ( ImplicitlyUnwrappedOptionalTypeSyntax . self) {
435+ return impOptType. wrappedType
436+ }
437+ return type
438+ }
439+
440+ func castPointerToOpaquePointer( _ baseAddress: ExprSyntax ) throws -> ExprSyntax {
441+ let i = try getParameterIndexForParamName ( signature. parameterClause. parameters, name)
442+ let type = peelOptionalType ( getParam ( signature, i) . type)
443+ if type. canRepresentBasicType ( type: OpaquePointer . self) {
444+ return ExprSyntax ( " OpaquePointer( \( baseAddress) ) " )
445+ }
446+ return baseAddress
447+ }
448+
449+ func getPointerArg( ) throws -> ExprSyntax {
416450 if nullable {
417451 return ExprSyntax ( " \( name) ?.baseAddress " )
418452 }
@@ -450,7 +484,7 @@ struct CountedOrSizedPointerThunkBuilder: PointerBoundsThunkBuilder {
450484 return unwrappedCall
451485 }
452486
453- args [ index] = getPointerArg ( )
487+ args [ index] = try castPointerToOpaquePointer ( getPointerArg ( ) )
454488 return try base. buildFunctionCall ( args, variant)
455489 }
456490}
@@ -499,22 +533,28 @@ func getOptionalArgumentByName(_ argumentList: LabeledExprListSyntax, _ name: St
499533 } ) ? . expression
500534}
501535
502- func getParameterIndexForDeclRef (
503- _ parameterList: FunctionParameterListSyntax , _ ref : DeclReferenceExprSyntax
536+ func getParameterIndexForParamName (
537+ _ parameterList: FunctionParameterListSyntax , _ tok : TokenSyntax
504538) throws -> Int {
505- let name = ref . baseName . text
539+ let name = tok . text
506540 guard
507541 let index = parameterList. enumerated ( ) . first ( where: {
508542 ( _: Int , param: FunctionParameterSyntax ) in
509543 let paramenterName = param. secondName ?? param. firstName
510544 return paramenterName. trimmed. text == name
511545 } ) ? . offset
512546 else {
513- throw DiagnosticError ( " no parameter with name ' \( name) ' in ' \( parameterList) ' " , node: ref )
547+ throw DiagnosticError ( " no parameter with name ' \( name) ' in ' \( parameterList) ' " , node: tok )
514548 }
515549 return index
516550}
517551
552+ func getParameterIndexForDeclRef(
553+ _ parameterList: FunctionParameterListSyntax , _ ref: DeclReferenceExprSyntax
554+ ) throws -> Int {
555+ return try getParameterIndexForParamName ( ( parameterList) , ref. baseName)
556+ }
557+
518558/// A macro that adds safe(r) wrappers for functions with unsafe pointer types.
519559/// Depends on bounds, escapability and lifetime information for each pointer.
520560/// Intended to map to C attributes like __counted_by, __ended_by and __no_escape,
0 commit comments