@@ -329,10 +329,7 @@ func transformType(
329329 let text = name. text
330330 let isRaw = isRawPointerType ( text: text)
331331 if isRaw && !isSizedBy {
332- throw DiagnosticError ( " raw pointers only supported for SizedBy " , node: name)
333- }
334- if !isRaw && isSizedBy {
335- throw DiagnosticError ( " SizedBy only supported for raw pointers " , node: name)
332+ throw DiagnosticError ( " void pointers not supported for countedBy " , node: name)
336333 }
337334
338335 guard let kind: Mutability = getPointerMutability ( text: text) else {
@@ -375,6 +372,33 @@ func isMutablePointerType(_ type: TypeSyntax) -> Bool {
375372 }
376373}
377374
375+ func getPointeeType( _ type: TypeSyntax ) -> TypeSyntax ? {
376+ if let optType = type. as ( OptionalTypeSyntax . self) {
377+ return getPointeeType ( optType. wrappedType)
378+ }
379+ if let impOptType = type. as ( ImplicitlyUnwrappedOptionalTypeSyntax . self) {
380+ return getPointeeType ( impOptType. wrappedType)
381+ }
382+ if let attrType = type. as ( AttributedTypeSyntax . self) {
383+ return getPointeeType ( attrType. baseType)
384+ }
385+
386+ guard let idType = type. as ( IdentifierTypeSyntax . self) else {
387+ return nil
388+ }
389+ let text = idType. name. text
390+ if text != " UnsafePointer " && text != " UnsafeMutablePointer " {
391+ return nil
392+ }
393+ guard let x = idType. genericArgumentClause else {
394+ return nil
395+ }
396+ guard let y = x. arguments. first else {
397+ return nil
398+ }
399+ return y. argument. as ( TypeSyntax . self)
400+ }
401+
378402protocol BoundsCheckedThunkBuilder {
379403 func buildFunctionCall( _ pointerArgs: [ Int : ExprSyntax ] ) throws -> ExprSyntax
380404 // buildBasicBoundsChecks creates a variable with the same name as the parameter it replaced,
@@ -652,6 +676,7 @@ extension PointerBoundsThunkBuilder {
652676 return try transformType ( oldType, generateSpan, isSizedBy, isParameter)
653677 }
654678 }
679+
655680 var countLabel : String {
656681 return isSizedBy && generateSpan ? " byteCount " : " count "
657682 }
@@ -830,7 +855,7 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
830855 var args = argOverrides
831856 let argExpr = ExprSyntax ( " \( unwrappedName) .baseAddress " )
832857 assert ( args [ index] == nil )
833- args [ index] = try castPointerToOpaquePointer ( unwrapIfNonnullable ( argExpr) )
858+ args [ index] = try castPointerToTargetType ( unwrapIfNonnullable ( argExpr) )
834859 let call = try base. buildFunctionCall ( args)
835860 let ptrRef = unwrapIfNullable ( " \( name) " )
836861
@@ -875,11 +900,16 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
875900 return type
876901 }
877902
878- func castPointerToOpaquePointer ( _ baseAddress: ExprSyntax ) throws -> ExprSyntax {
903+ func castPointerToTargetType ( _ baseAddress: ExprSyntax ) throws -> ExprSyntax {
879904 let type = peelOptionalType ( getParam ( signature, index) . type)
880905 if type. canRepresentBasicType ( type: OpaquePointer . self) {
881906 return ExprSyntax ( " OpaquePointer( \( baseAddress) ) " )
882907 }
908+ if isSizedBy {
909+ if let pointeeType = getPointeeType ( type) {
910+ return " \( baseAddress) .assumingMemoryBound(to: \( pointeeType) .self) "
911+ }
912+ }
883913 return baseAddress
884914 }
885915
@@ -911,7 +941,7 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
911941 return unwrappedCall
912942 }
913943
914- args [ index] = try castPointerToOpaquePointer ( getPointerArg ( ) )
944+ args [ index] = try castPointerToTargetType ( getPointerArg ( ) )
915945 return try base. buildFunctionCall ( args)
916946 }
917947}
0 commit comments