Skip to content

Commit fe368ad

Browse files
authored
Merge pull request #25 from cpunion/signature
auto generate function signature
2 parents f2c4b10 + d4b082e commit fe368ad

File tree

10 files changed

+725
-112
lines changed

10 files changed

+725
-112
lines changed

README.md

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ demo.launch()
258258

259259
var gr Module
260260

261-
func UpdateExamples(country string) Object {
261+
func updateExamples(country string) Object {
262262
println("country:", country)
263263
if country == "USA" {
264264
return gr.Call("Dataset", KwArgs{
@@ -280,10 +280,6 @@ func main() {
280280
Initialize()
281281
defer Finalize()
282282
gr = ImportModule("gradio")
283-
fn := CreateFunc("update_examples", UpdateExamples,
284-
"(country, /)\n--\n\nUpdate examples based on country")
285-
// Would be (in the future):
286-
// fn := FuncOf(UpdateExamples)
287283
demo := With(gr.Call("Blocks"), func(v Object) {
288284
dropdown := gr.Call("Dropdown", KwArgs{
289285
"label": "Country",
@@ -293,7 +289,7 @@ func main() {
293289
textbox := gr.Call("Textbox")
294290
examples := gr.Call("Examples", [][]string{{"Chicago"}, {"Little Rock"}, {"San Francisco"}}, textbox)
295291
dataset := examples.Attr("dataset")
296-
dropdown.Call("change", fn, dropdown, dataset)
292+
dropdown.Call("change", updateExamples, dropdown, dataset)
297293
})
298294
demo.Call("launch")
299295
}

convert.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ func From(from any) Object {
7676
return fromMap(vv).Object
7777
case reflect.Struct:
7878
return fromStruct(vv)
79+
case reflect.Func:
80+
return FuncOf(vv.Interface()).Object
7981
}
8082
panic(fmt.Errorf("unsupported type for Python: %T\n", v))
8183
}

convert_test.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,61 @@ func TestFromSpecialCases(t *testing.T) {
197197
t.Errorf("Object was not independent, got %d after modifying original", got)
198198
}
199199
}()
200+
201+
func() {
202+
// Test From with functions
203+
add := func(a, b int) int { return a + b }
204+
obj := From(add)
205+
206+
// Verify it's a function type
207+
if !obj.IsFunc() {
208+
t.Error("From(func) did not create Function object")
209+
}
210+
211+
fn := obj.AsFunc()
212+
213+
// Test function call
214+
result := fn.Call(5, 3)
215+
216+
if !result.IsLong() {
217+
t.Error("Function call result is not a Long")
218+
}
219+
if got := result.AsLong().Int64(); got != 8 {
220+
t.Errorf("Function call = %d, want 8", got)
221+
}
222+
}()
223+
224+
func() {
225+
// Test From with function that returns multiple values
226+
divMod := func(a, b int) (int, int) {
227+
return a / b, a % b
228+
}
229+
obj := From(divMod)
230+
if !obj.IsFunc() {
231+
t.Error("From(func) did not create Function object")
232+
}
233+
234+
fn := obj.AsFunc()
235+
236+
result := fn.Call(7, 3)
237+
238+
// Result should be a tuple with two values
239+
if !result.IsTuple() {
240+
t.Error("Multiple return value function did not return a Tuple")
241+
}
242+
243+
tuple := result.AsTuple()
244+
if tuple.Len() != 2 {
245+
t.Errorf("Expected tuple of length 2, got %d", tuple.Len())
246+
}
247+
248+
quotient := tuple.Get(0).AsLong().Int64()
249+
remainder := tuple.Get(1).AsLong().Int64()
250+
251+
if quotient != 2 || remainder != 1 {
252+
t.Errorf("Got (%d, %d), want (2, 1)", quotient, remainder)
253+
}
254+
}()
200255
}
201256

202257
func TestToValueWithCustomType(t *testing.T) {

demo/gradio/gradio.go

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ demo.launch()
2727

2828
var gr Module
2929

30-
func UpdateExamples(country string) Object {
30+
func updateExamples(country string) Object {
3131
println("country:", country)
3232
if country == "USA" {
3333
return gr.Call("Dataset", KwArgs{
@@ -49,10 +49,6 @@ func main() {
4949
Initialize()
5050
defer Finalize()
5151
gr = ImportModule("gradio")
52-
fn := CreateFunc("update_examples", UpdateExamples,
53-
"(country, /)\n--\n\nUpdate examples based on country")
54-
// Would be (in the future):
55-
// fn := FuncOf(UpdateExamples)
5652
demo := With(gr.Call("Blocks"), func(v Object) {
5753
dropdown := gr.Call("Dropdown", KwArgs{
5854
"label": "Country",
@@ -62,7 +58,7 @@ func main() {
6258
textbox := gr.Call("Textbox")
6359
examples := gr.Call("Examples", [][]string{{"Chicago"}, {"Little Rock"}, {"San Francisco"}}, textbox)
6460
dataset := examples.Attr("dataset")
65-
dropdown.Call("change", fn, dropdown, dataset)
61+
dropdown.Call("change", updateExamples, dropdown, dataset)
6662
})
6763
demo.Call("launch")
6864
}

extension.go

Lines changed: 150 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ import (
2525
"unsafe"
2626
)
2727

28+
func FuncOf(fn any) Func {
29+
return CreateFunc("", fn, "")
30+
}
31+
2832
func CreateFunc(name string, fn any, doc string) Func {
2933
m := MainModule()
3034
return m.AddMethod(name, fn, doc)
@@ -583,7 +587,24 @@ func (m Module) AddMethod(name string, fn any, doc string) Func {
583587
}
584588
}
585589
name = goNameToPythonName(name)
586-
doc = name + doc
590+
591+
hasRecv := false
592+
if t.NumIn() > 0 {
593+
firstParam := t.In(0)
594+
if firstParam.Kind() == reflect.Ptr || firstParam.Kind() == reflect.Interface {
595+
hasRecv = true
596+
}
597+
}
598+
599+
kwargsType := reflect.TypeOf(KwArgs{})
600+
hasKwArgs := false
601+
if t.NumIn() > 0 && t.In(t.NumIn()-1) == kwargsType {
602+
hasKwArgs = true
603+
}
604+
605+
sig := genSig(fn, hasRecv)
606+
fullDoc := name + sig + "\n--\n\n" + doc
607+
cDoc := C.CString(fullDoc)
587608

588609
maps := getGlobalData()
589610
meta, ok := maps.typeMetas[m.obj]
@@ -596,23 +617,25 @@ func (m Module) AddMethod(name string, fn any, doc string) Func {
596617

597618
methodId := uint(len(meta.methods))
598619

599-
methodPtr := C.wrapperMethods[methodId]
600620
cName := C.CString(name)
601-
cDoc := C.CString(doc)
602621

603622
def := (*C.PyMethodDef)(C.malloc(C.size_t(unsafe.Sizeof(C.PyMethodDef{}))))
604623
def.ml_name = cName
605-
def.ml_meth = C.PyCFunction(methodPtr)
624+
def.ml_meth = C.PyCFunction(C.wrapperMethods[methodId])
606625
def.ml_flags = C.METH_VARARGS
626+
if hasKwArgs {
627+
def.ml_flags |= C.METH_KEYWORDS
628+
def.ml_meth = C.PyCFunction(C.wrapperMethodsWithKwargs[methodId])
629+
}
607630
def.ml_doc = cDoc
608631

609632
methodMeta := &slotMeta{
610633
name: name,
611634
methodName: name,
612635
fn: fn,
613636
typ: t,
614-
doc: doc,
615-
hasRecv: false,
637+
doc: fullDoc,
638+
hasRecv: hasRecv,
616639
def: def,
617640
}
618641
meta.methods[methodId] = methodMeta
@@ -665,3 +688,124 @@ func FetchError() error {
665688

666689
return fmt.Errorf("python error: %s", C.GoString(cstr))
667690
}
691+
692+
func genSig(fn any, hasRecv bool) string {
693+
t := reflect.TypeOf(fn)
694+
if t.Kind() != reflect.Func {
695+
panic("genSig: fn must be a function")
696+
}
697+
698+
var args []string
699+
startIdx := 0
700+
if hasRecv {
701+
startIdx = 1 // skip receiver
702+
}
703+
704+
kwargsType := reflect.TypeOf(KwArgs{})
705+
hasKwArgs := false
706+
lastParamIdx := t.NumIn() - 1
707+
if lastParamIdx >= startIdx && t.In(lastParamIdx) == kwargsType {
708+
hasKwArgs = true
709+
lastParamIdx-- // don't include KwArgs in regular parameters
710+
}
711+
712+
for i := startIdx; i <= lastParamIdx; i++ {
713+
paramName := fmt.Sprintf("arg%d", i-startIdx)
714+
args = append(args, paramName)
715+
}
716+
717+
// add "/" separator only if there are parameters
718+
if len(args) > 0 {
719+
args = append(args, "/")
720+
}
721+
722+
// add "**kwargs" if there are keyword arguments
723+
if hasKwArgs {
724+
args = append(args, "**kwargs")
725+
}
726+
727+
return fmt.Sprintf("(%s)", strings.Join(args, ", "))
728+
}
729+
730+
//export wrapperMethodWithKwargs
731+
func wrapperMethodWithKwargs(self, args, kwargs *C.PyObject, methodId C.int) *C.PyObject {
732+
key := self
733+
if C.isModule(self) == 0 {
734+
key = (*C.PyObject)(unsafe.Pointer(self.ob_type))
735+
}
736+
737+
maps := getGlobalData()
738+
typeMeta, ok := maps.typeMetas[key]
739+
check(ok, fmt.Sprintf("type %v not registered", FromPy(key)))
740+
741+
methodMeta := typeMeta.methods[uint(methodId)]
742+
methodType := methodMeta.typ
743+
hasReceiver := methodMeta.hasRecv
744+
745+
expectedArgs := methodType.NumIn()
746+
if hasReceiver {
747+
expectedArgs-- // skip receiver
748+
}
749+
expectedArgs-- // skip KwArgs
750+
751+
argc := C.PyTuple_Size(args)
752+
if int(argc) != expectedArgs {
753+
SetTypeError(fmt.Errorf("method %s expects %d arguments, got %d", methodMeta.name, expectedArgs, argc))
754+
return nil
755+
}
756+
757+
goArgs := make([]reflect.Value, methodType.NumIn())
758+
argIndex := 0
759+
760+
if hasReceiver {
761+
wrapper := (*wrapperType)(unsafe.Pointer(self))
762+
receiverType := methodType.In(0)
763+
var recv reflect.Value
764+
765+
if receiverType.Kind() == reflect.Ptr {
766+
recv = reflect.ValueOf(wrapper.goObj)
767+
} else {
768+
recv = reflect.ValueOf(wrapper.goObj).Elem()
769+
}
770+
771+
goArgs[0] = recv
772+
argIndex = 1
773+
}
774+
775+
for i := 0; i < int(argc); i++ {
776+
arg := C.PySequence_GetItem(args, C.Py_ssize_t(i))
777+
argType := methodType.In(i + argIndex)
778+
argPy := FromPy(arg)
779+
goValue := reflect.New(argType).Elem()
780+
if !ToValue(argPy, goValue) {
781+
SetTypeError(fmt.Errorf("failed to convert argument %v to %v", argPy, argType))
782+
return nil
783+
}
784+
goArgs[i+argIndex] = goValue
785+
}
786+
787+
kwargsValue := make(KwArgs)
788+
if kwargs != nil {
789+
dict := newDict(kwargs)
790+
dict.Items()(func(key, value Object) bool {
791+
kwargsValue[key.String()] = value
792+
return true
793+
})
794+
}
795+
goArgs[len(goArgs)-1] = reflect.ValueOf(kwargsValue)
796+
797+
results := reflect.ValueOf(methodMeta.fn).Call(goArgs)
798+
799+
if len(results) == 0 {
800+
return None().cpyObj()
801+
}
802+
if len(results) == 1 {
803+
return From(results[0].Interface()).cpyObj()
804+
}
805+
806+
tuple := MakeTupleWithLen(len(results))
807+
for i := range results {
808+
tuple.Set(i, From(results[i].Interface()))
809+
}
810+
return tuple.cpyObj()
811+
}

0 commit comments

Comments
 (0)