Skip to content

Commit 361e6d4

Browse files
authored
Merge pull request #40 from liuliu/liuliu/python-function
Allow defining Swift functions to be called from Python with revisions
2 parents 839ef68 + 108ac70 commit 361e6d4

File tree

2 files changed

+159
-1
lines changed

2 files changed

+159
-1
lines changed

PythonKit/Python.swift

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1400,6 +1400,7 @@ extension PythonObject : Sequence {
14001400
}
14011401
}
14021402

1403+
14031404
//===----------------------------------------------------------------------===//
14041405
// `ExpressibleByLiteral` conformances
14051406
//===----------------------------------------------------------------------===//
@@ -1517,3 +1518,145 @@ public struct PythonBytes : PythonConvertible, ConvertibleFromPython, Hashable {
15171518
}
15181519
}
15191520
}
1521+
1522+
//===----------------------------------------------------------------------===//
1523+
// PythonFunction - create functions in Swift that can be called from Python
1524+
//===----------------------------------------------------------------------===//
1525+
1526+
/// Create functions in Swift that can be called from Python
1527+
///
1528+
/// Example:
1529+
///
1530+
/// The Python code `map(lambda(x: x * 2), [10, 12, 14])` would be written as:
1531+
///
1532+
/// Python.map(PythonFunction { x in x * 2 }, [10, 12, 14]) // [20, 24, 28]
1533+
///
1534+
final class PyFunction {
1535+
private var callSwiftFunction: (_ argumentsTuple: PythonObject) throws -> PythonConvertible
1536+
init(_ callSwiftFunction: @escaping (_ argumentsTuple: PythonObject) throws -> PythonConvertible) {
1537+
self.callSwiftFunction = callSwiftFunction
1538+
}
1539+
func callAsFunction(_ argumentsTuple: PythonObject) throws -> PythonConvertible {
1540+
try callSwiftFunction(argumentsTuple)
1541+
}
1542+
}
1543+
1544+
public struct PythonFunction {
1545+
/// Called directly by the Python C API
1546+
private var function: PyFunction
1547+
1548+
public init(_ fn: @escaping (PythonObject) throws -> PythonConvertible) {
1549+
function = PyFunction { argumentsAsTuple in
1550+
return try fn(argumentsAsTuple[0])
1551+
}
1552+
}
1553+
1554+
/// For cases where the Swift function should accept more (or less) than one parameter, accept an ordered array of all arguments instead
1555+
public init(_ fn: @escaping ([PythonObject]) throws -> PythonConvertible) {
1556+
function = PyFunction { argumentsAsTuple in
1557+
return try fn(argumentsAsTuple.map { $0 })
1558+
}
1559+
}
1560+
1561+
}
1562+
1563+
extension PythonFunction : PythonConvertible {
1564+
public var pythonObject: PythonObject {
1565+
// Ensure Python is initialized, and check for version match.
1566+
let versionMajor = Python.versionInfo.major
1567+
let versionMinor = Python.versionInfo.minor
1568+
guard (versionMajor == 3 && versionMinor >= 1) || versionMajor > 3 else {
1569+
fatalError("PythonFunction only supports Python 3.1 and above.")
1570+
}
1571+
1572+
let funcPointer = Unmanaged.passRetained(function).toOpaque()
1573+
let capsulePointer = PyCapsule_New(funcPointer, nil, { capsulePointer in
1574+
let funcPointer = PyCapsule_GetPointer(capsulePointer, nil)
1575+
Unmanaged<PyFunction>.fromOpaque(funcPointer).release()
1576+
})
1577+
1578+
let pyFuncPointer = PyCFunction_New(
1579+
PythonFunction.sharedMethodDefinition,
1580+
capsulePointer
1581+
)
1582+
1583+
return PythonObject(consuming: pyFuncPointer)
1584+
}
1585+
}
1586+
1587+
fileprivate extension PythonFunction {
1588+
static let sharedFunctionName: UnsafePointer<Int8> = {
1589+
let name: StaticString = "pythonkit_swift_function"
1590+
// `utf8Start` is a property of StaticString, thus, it has a stable pointer.
1591+
return UnsafeRawPointer(name.utf8Start).assumingMemoryBound(to: Int8.self)
1592+
}()
1593+
1594+
static let sharedMethodDefinition: UnsafeMutablePointer<PyMethodDef> = {
1595+
/// The standard calling convention. See Python C API docs
1596+
let METH_VARARGS = 1 as Int32
1597+
1598+
let pointer = UnsafeMutablePointer<PyMethodDef>.allocate(capacity: 1)
1599+
pointer.pointee = PyMethodDef(
1600+
ml_name: PythonFunction.sharedFunctionName,
1601+
ml_meth: PythonFunction.sharedMethodImplementation,
1602+
ml_flags: METH_VARARGS,
1603+
ml_doc: nil
1604+
)
1605+
1606+
return pointer
1607+
}()
1608+
1609+
private static let sharedMethodImplementation: @convention(c) (PyObjectPointer?, PyObjectPointer?) -> PyObjectPointer? = { context, argumentsPointer in
1610+
guard let argumentsPointer = argumentsPointer, let capsulePointer = context else {
1611+
return nil
1612+
}
1613+
1614+
let funcPointer = PyCapsule_GetPointer(capsulePointer, nil)
1615+
let function = Unmanaged<PyFunction>.fromOpaque(funcPointer).takeUnretainedValue()
1616+
1617+
do {
1618+
let argumentsAsTuple = PythonObject(consuming: argumentsPointer)
1619+
return try function(argumentsAsTuple).ownedPyObject
1620+
} catch {
1621+
PythonFunction.setPythonError(swiftError: error)
1622+
return nil // This must only be `nil` if an exception has been set
1623+
}
1624+
}
1625+
1626+
private static func setPythonError(swiftError: Error) {
1627+
if let pythonObject = swiftError as? PythonObject {
1628+
if Bool(Python.isinstance(pythonObject, Python.BaseException))! {
1629+
// We are an instance of an Exception class type. Set the exception class to the object's type:
1630+
PyErr_SetString(Python.type(pythonObject).ownedPyObject, pythonObject.description)
1631+
} else {
1632+
// Assume an actual class type was thrown (rather than an instance)
1633+
// Crashes if it was neither a subclass of BaseException nor an instance of one.
1634+
//
1635+
// We *could* check to see whether `pythonObject` is a class here and fall back
1636+
// to the default case of setting a generic Exception, below, but we also want
1637+
// people to write valid code.
1638+
PyErr_SetString(pythonObject.ownedPyObject, pythonObject.description)
1639+
}
1640+
} else {
1641+
// Make a generic Python Exception based on the Swift Error:
1642+
PyErr_SetString(Python.Exception.ownedPyObject, "\(type(of: swiftError)) raised in Swift: \(swiftError)")
1643+
}
1644+
}
1645+
}
1646+
1647+
extension PythonObject: Error {}
1648+
1649+
// From Python's C Headers:
1650+
struct PyMethodDef {
1651+
/// The name of the built-in function/method
1652+
public var ml_name: UnsafePointer<Int8>
1653+
1654+
/// The C function that implements it
1655+
public var ml_meth: @convention(c) (PyObjectPointer?, PyObjectPointer?) -> PyObjectPointer?
1656+
1657+
/// Combination of METH_xxx flags, which mostly describe the args expected by the C func
1658+
public var ml_flags: Int32
1659+
1660+
/// The __doc__ attribute, or NULL
1661+
public var ml_doc: UnsafePointer<Int8>?
1662+
}

PythonKit/PythonLibrary+Symbols.swift

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,14 @@
2020

2121
@usableFromInline
2222
typealias PyObjectPointer = UnsafeMutableRawPointer
23+
typealias PyMethodDefPointer = UnsafeMutableRawPointer
2324
typealias PyCCharPointer = UnsafePointer<Int8>
2425
typealias PyBinaryOperation =
2526
@convention(c) (PyObjectPointer?, PyObjectPointer?) -> PyObjectPointer?
2627
typealias PyUnaryOperation =
27-
@convention(c) (PyObjectPointer?) -> PyObjectPointer?
28+
@convention(c) (PyObjectPointer?) -> PyObjectPointer?
29+
typealias PyCapsuleDestructor =
30+
@convention(c) (PyObjectPointer?) -> Void
2831

2932
let Py_LT: Int32 = 0
3033
let Py_LE: Int32 = 1
@@ -56,6 +59,18 @@ let PyEval_GetBuiltins: @convention(c) () -> PyObjectPointer =
5659
let PyRun_SimpleString: @convention(c) (PyCCharPointer) -> Void =
5760
PythonLibrary.loadSymbol(name: "PyRun_SimpleString")
5861

62+
let PyCFunction_New: @convention(c) (PyMethodDefPointer, UnsafeMutableRawPointer) -> PyObjectPointer =
63+
PythonLibrary.loadSymbol(name: "PyCFunction_New")
64+
65+
let PyCapsule_New: @convention(c) (UnsafeMutableRawPointer, UnsafePointer<CChar>?, PyCapsuleDestructor) -> PyObjectPointer =
66+
PythonLibrary.loadSymbol(name: "PyCapsule_New")
67+
68+
let PyCapsule_GetPointer: @convention(c) (PyObjectPointer?, UnsafePointer<CChar>?) -> UnsafeMutableRawPointer =
69+
PythonLibrary.loadSymbol(name: "PyCapsule_GetPointer")
70+
71+
let PyErr_SetString: @convention(c) (PyObjectPointer, UnsafePointer<CChar>?) -> Void =
72+
PythonLibrary.loadSymbol(name: "PyErr_SetString")
73+
5974
let PyErr_Occurred: @convention(c) () -> PyObjectPointer? =
6075
PythonLibrary.loadSymbol(name: "PyErr_Occurred")
6176

0 commit comments

Comments
 (0)