diff --git a/debug_test.go b/debug_test.go index 0b1548e66f..3dbb84a58a 100644 --- a/debug_test.go +++ b/debug_test.go @@ -65,6 +65,52 @@ func TestPrintln(t *testing.T) { } } +// ------------------------------------------------------------------------------------------------- +// test printf (non regression) +type printfCircuit struct { + A, B frontend.Variable +} + +func (circuit *printfCircuit) Define(api frontend.API) error { + c := api.Add(circuit.A, circuit.B) + api.Printf("%d is the addition", c) + d := api.Mul(circuit.A, c) + api.Printf("%s %s", d, new(big.Int).SetInt64(42)) + bs := api.ToBinary(circuit.B, 10) + api.Printf("bits %s", bs[3]) + api.Printf("circuit %s", circuit) + nb := api.Mul(bs[1], 2) + api.AssertIsBoolean(nb) // this will fail + m := api.Mul(circuit.A, circuit.B) + api.Printf("m %s", m) // this should not be resolved + return nil +} + +func TestPrintf(t *testing.T) { + assert := require.New(t) + + var circuit, witness printfCircuit + witness.A = 2 + witness.B = 11 + + var expected bytes.Buffer + expected.WriteString("debug_test.go:30 > 13 is the addition\n") + expected.WriteString("debug_test.go:32 > 26 42\n") + expected.WriteString("debug_test.go:34 > bits 1\n") + expected.WriteString("debug_test.go:35 > circuit {A: 2, B: 11}\n") + expected.WriteString("debug_test.go:39 > m .*\n") + + { + trace, _ := getGroth16Trace(&circuit, &witness) + assert.Regexp(expected.String(), trace) + } + + { + trace, _ := getPlonkTrace(&circuit, &witness) + assert.Regexp(expected.String(), trace) + } +} + // ------------------------------------------------------------------------------------------------- // Div by 0 type divBy0Trace struct { diff --git a/frontend/api.go b/frontend/api.go index 5bfb761574..b1cbbf31b7 100644 --- a/frontend/api.go +++ b/frontend/api.go @@ -128,6 +128,10 @@ type API interface { // whose value will be resolved at runtime when computed by the solver Println(a ...Variable) + // Printf behaves like fmt.Printf but accepts cd.Variable as parameters + // whose value will be resolved at runtime when computed by the solver + Printf(format string, args ...Variable) + // Compiler returns the compiler object for advanced circuit development Compiler() Compiler diff --git a/frontend/cs/r1cs/api.go b/frontend/cs/r1cs/api.go index e754aa14da..bfc3d4a8b4 100644 --- a/frontend/cs/r1cs/api.go +++ b/frontend/cs/r1cs/api.go @@ -628,6 +628,68 @@ func (builder *builder) Println(a ...frontend.Variable) { builder.cs.AddLog(log) } +func (builder *builder) Printf(format string, args ...frontend.Variable) { + var log constraint.LogEntry + + // prefix log line with file.go:line + if _, file, line, ok := runtime.Caller(1); ok { + log.Caller = fmt.Sprintf("%s:%d", filepath.Base(file), line) + } + + var sbb strings.Builder + formatIndex := 0 + + // Parse the format string and match placeholders with args + for _, arg := range args { + // Search for the next format specifier + nextPercent := strings.Index(format[formatIndex:], "%") + if nextPercent == -1 { + // No more placeholders; add remaining format string and break + sbb.WriteString(format[formatIndex:]) + break + } + + // Add the part of the format string before the next % + sbb.WriteString(format[formatIndex : formatIndex+nextPercent]) + formatIndex += nextPercent + 1 // Move past % + + // Handle format specifier + if formatIndex < len(format) { + specifier := format[formatIndex] + formatIndex++ // Move past the specifier + + switch specifier { + case 's', 'd', 'f', 'x': // Supported format specifiers + if v, ok := arg.(expr.LinearExpression); ok { + assertIsSet(v) + sbb.WriteString("%" + string(specifier)) + log.ToResolve = append(log.ToResolve, builder.getLinearExpression(v)) + } else { + builder.printArg(&log, &sbb, arg) + } + default: + // Unsupported specifier; add it directly + sbb.WriteByte('%') + sbb.WriteByte(specifier) + } + } else { + // Malformed format string (ends with %) + sbb.WriteByte('%') + } + } + + // Add remaining format string after the last placeholder + if formatIndex < len(format) { + sbb.WriteString(format[formatIndex:]) + } + + // Set the format string for the log entry + log.Format = sbb.String() + + // Add the log entry to the circuit's constraint system + builder.cs.AddLog(log) +} + func (builder *builder) printArg(log *constraint.LogEntry, sbb *strings.Builder, a frontend.Variable) { leafCount, err := schema.Walk(a, tVariable, nil) diff --git a/frontend/cs/scs/api.go b/frontend/cs/scs/api.go index 3d9e202f2e..969696c33e 100644 --- a/frontend/cs/scs/api.go +++ b/frontend/cs/scs/api.go @@ -539,6 +539,37 @@ func (builder *builder) Println(a ...frontend.Variable) { builder.cs.AddLog(log) } +func (builder *builder) Printf(format string, args ...frontend.Variable) { + var log constraint.LogEntry + + // prefix log line with file.go:line + if _, file, line, ok := runtime.Caller(1); ok { + log.Caller = fmt.Sprintf("%s:%d", filepath.Base(file), line) + } + + var sbb strings.Builder + + for i, arg := range args { + if i > 0 { + sbb.WriteByte(' ') + } + if v, ok := arg.(expr.Term); ok { + + sbb.WriteString("%s") + // we set limits to the linear expression, so that the log printer + // can evaluate it before printing it + log.ToResolve = append(log.ToResolve, constraint.LinearExpression{builder.cs.MakeTerm(v.Coeff, v.VID)}) + } else { + builder.printArg(&log, &sbb, arg) + } + } + + // set format string to be used with fmt.Sprintf, once the variables are solved in the R1CS.Solve() method + log.Format = sbb.String() + + builder.cs.AddLog(log) +} + func (builder *builder) printArg(log *constraint.LogEntry, sbb *strings.Builder, a frontend.Variable) { leafCount, err := schema.Walk(a, tVariable, nil) diff --git a/test/engine.go b/test/engine.go index fd84ab4eea..e3d6e3b737 100644 --- a/test/engine.go +++ b/test/engine.go @@ -515,6 +515,25 @@ func (e *engine) print(sbb *strings.Builder, x interface{}) { } } +func (e *engine) Printf(format string, args ...frontend.Variable) { + var sbb strings.Builder + sbb.WriteString("(test.engine) ") + + // prefix log line with file.go:line + if _, file, line, ok := runtime.Caller(1); ok { + sbb.WriteString(filepath.Base(file)) + sbb.WriteByte(':') + sbb.WriteString(strconv.Itoa(line)) + sbb.WriteByte(' ') + } + + for i := 0; i < len(args); i++ { + e.print(&sbb, args[i]) + sbb.WriteByte(' ') + } + fmt.Println(sbb.String()) +} + func (e *engine) NewHint(f solver.Hint, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { if nbOutputs <= 0 {