Skip to content

Commit 351dc4e

Browse files
goccyclaude
andauthored
Optimize WASM plugin execution with compilation caching (#363)
* optimize WASM plugin execution with compilation caching * fix WASM plugin cache key to include sha256 and return errors from Close Include opt.Sha256 in the cache key so that the same path with different hash values does not bypass hash verification. Return errors from Close methods using errors.Join instead of silently discarding them. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * address PR #363 review feedback - Use content-derived sha256 as cache key instead of user-provided value - Make wasmPluginCache thread-safe with sync.RWMutex - Add defer g.Close(ctx) in CLI entry point - Return SupportedFeatures even when result.Files is empty Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * remove counter field --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 5eeb102 commit 351dc4e

File tree

3 files changed

+139
-51
lines changed

3 files changed

+139
-51
lines changed

cmd/grpc-federation-generator/main.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ func _main(ctx context.Context, args []string, opt *option) error {
4343
protoPath = args[0]
4444
}
4545
g := generator.New(cfg)
46+
defer g.Close(ctx)
4647
var opts []generator.Option
4748
if opt.WatchMode {
4849
opts = append(opts, generator.WatchMode())

generator/generator.go

Lines changed: 43 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ package generator
33
import (
44
"bytes"
55
"context"
6-
"crypto/sha256"
7-
"encoding/hex"
86
"errors"
97
"fmt"
108
"io"
@@ -25,7 +23,6 @@ import (
2523

2624
"github.com/mercari/grpc-federation/compiler"
2725
"github.com/mercari/grpc-federation/grpc/federation/generator"
28-
"github.com/mercari/grpc-federation/grpc/federation/generator/plugin"
2926
"github.com/mercari/grpc-federation/resolver"
3027
"github.com/mercari/grpc-federation/source"
3128
"github.com/mercari/grpc-federation/validator"
@@ -47,6 +44,7 @@ type Generator struct {
4744
postProcessHandler PostProcessHandler
4845
buildCacheMap BuildCacheMap
4946
absPathToRelativePath map[string]string
47+
wasmPluginCache *wasmPluginCache
5048
}
5149

5250
type Option func(*Generator) error
@@ -135,9 +133,18 @@ func New(cfg Config) *Generator {
135133
validator: validator.New(),
136134
importPaths: cfg.Imports,
137135
absPathToRelativePath: make(map[string]string),
136+
wasmPluginCache: newWasmPluginCache(),
138137
}
139138
}
140139

140+
// Close releases resources held by the Generator, including cached WASM plugin runtimes.
141+
func (g *Generator) Close(ctx context.Context) error {
142+
if g.wasmPluginCache != nil {
143+
return g.wasmPluginCache.Close(ctx)
144+
}
145+
return nil
146+
}
147+
141148
func (g *Generator) SetPostProcessHandler(postProcessHandler func(ctx context.Context, path string, result Result) error) {
142149
g.postProcessHandler = postProcessHandler
143150
}
@@ -182,7 +189,7 @@ func (g *Generator) Generate(ctx context.Context, protoPath string, opts ...Opti
182189
}
183190
results = append(results, result)
184191
}
185-
pluginResp, err := evalAllCodeGenerationPlugin(ctx, results, g.federationGeneratorOption)
192+
pluginResp, err := evalAllCodeGenerationPlugin(ctx, results, g.federationGeneratorOption, g.wasmPluginCache)
186193
if err != nil {
187194
return err
188195
}
@@ -359,7 +366,7 @@ func (g *Generator) setWatcher(w *Watcher) error {
359366
results = append(results, result)
360367
}
361368
results = append(results, g.otherResults(path)...)
362-
pluginResp, err := evalAllCodeGenerationPlugin(ctx, results, g.federationGeneratorOption)
369+
pluginResp, err := evalAllCodeGenerationPlugin(ctx, results, g.federationGeneratorOption, g.wasmPluginCache)
363370
if err != nil {
364371
log.Printf("failed to run code generator plugin: %+v", err)
365372
}
@@ -651,7 +658,7 @@ func (g *Generator) fileNameWithoutExt(name string) string {
651658
return name[:len(name)-len(filepath.Ext(name))]
652659
}
653660

654-
func CreateCodeGeneratorResponse(ctx context.Context, req *pluginpb.CodeGeneratorRequest) (*pluginpb.CodeGeneratorResponse, error) {
661+
func CreateCodeGeneratorResponse(ctx context.Context, req *pluginpb.CodeGeneratorRequest) (_ *pluginpb.CodeGeneratorResponse, err error) {
655662
opt, err := parseOptString(req.GetParameter())
656663
if err != nil {
657664
return nil, err
@@ -665,20 +672,33 @@ func CreateCodeGeneratorResponse(ctx context.Context, req *pluginpb.CodeGenerato
665672
fmt.Fprint(os.Stderr, validator.Format(outs))
666673
}
667674

675+
var resp pluginpb.CodeGeneratorResponse
676+
// TODO: Since we don’t currently support editions, we will comment it out.
677+
// Strictly speaking, proto3 optional is also not fully supported, but because it cannot be used together when other plugins support proto3 optional,
678+
// we have enabled it for the time being.
679+
resp.SupportedFeatures = proto.Uint64(uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL /*| pluginpb.CodeGeneratorResponse_FEATURE_SUPPORTS_EDITIONS*/))
680+
681+
if len(result.Files) == 0 {
682+
return &resp, nil
683+
}
684+
685+
cache := newWasmPluginCache()
686+
defer func() {
687+
if closeErr := cache.Close(ctx); closeErr != nil {
688+
err = errors.Join(err, closeErr)
689+
}
690+
}()
668691
pluginResp, err := evalAllCodeGenerationPlugin(ctx, []*ProtoFileResult{
669692
{
670693
Type: ProtocAction,
671694
ProtoPath: "",
672695
FederationFiles: result.Files,
673696
},
674-
}, opt)
697+
}, opt, cache)
675698
if err != nil {
676699
return nil, err
677700
}
678-
679-
var resp pluginpb.CodeGeneratorResponse
680701
resp.File = append(resp.File, pluginResp.File...)
681-
resp.SupportedFeatures = proto.Uint64(uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL /*| pluginpb.CodeGeneratorResponse_FEATURE_SUPPORTS_EDITIONS*/))
682702
for _, file := range result.Files {
683703
out, err := NewCodeGenerator().Generate(file)
684704
if err != nil {
@@ -696,61 +716,38 @@ func CreateCodeGeneratorResponse(ctx context.Context, req *pluginpb.CodeGenerato
696716
return &resp, nil
697717
}
698718

699-
func evalAllCodeGenerationPlugin(ctx context.Context, results []*ProtoFileResult, opt *CodeGeneratorOption) (*pluginpb.CodeGeneratorResponse, error) {
719+
func evalAllCodeGenerationPlugin(ctx context.Context, results []*ProtoFileResult, opt *CodeGeneratorOption, cache *wasmPluginCache) (*pluginpb.CodeGeneratorResponse, error) {
700720
if len(results) == 0 {
701721
return &pluginpb.CodeGeneratorResponse{}, nil
702722
}
703723
if opt == nil || len(opt.Plugins) == 0 {
704724
return &pluginpb.CodeGeneratorResponse{}, nil
705725
}
726+
706727
var resp pluginpb.CodeGeneratorResponse
707728
for _, result := range results {
708-
pluginFiles := make([]*plugin.ProtoCodeGeneratorResponse_File, 0, len(result.Files))
709-
for _, file := range result.Files {
710-
fileBytes, err := proto.Marshal(file)
729+
for _, p := range opt.Plugins {
730+
wp, err := cache.getOrCreate(ctx, p)
711731
if err != nil {
712732
return nil, err
713733
}
714-
var pluginFile plugin.ProtoCodeGeneratorResponse_File
715-
if err := proto.Unmarshal(fileBytes, &pluginFile); err != nil {
716-
return nil, err
717-
}
718-
pluginFiles = append(pluginFiles, &pluginFile)
719-
}
720-
genReq := generator.CreateCodeGeneratorRequest(&generator.CodeGeneratorRequestConfig{
721-
Type: generator.ActionType(result.Type),
722-
ProtoPath: result.ProtoPath,
723-
Files: pluginFiles,
724-
GRPCFederationFiles: result.FederationFiles,
725-
OutputFilePathConfig: opt.Path,
726-
})
727-
encodedGenReq, err := proto.Marshal(genReq)
728-
if err != nil {
729-
return nil, err
730-
}
731-
for _, plugin := range opt.Plugins {
732-
wasmFile, err := os.ReadFile(plugin.Path)
734+
genReq := generator.CreateCodeGeneratorRequest(&generator.CodeGeneratorRequestConfig{
735+
ProtoPath: result.ProtoPath,
736+
GRPCFederationFiles: result.FederationFiles,
737+
OutputFilePathConfig: opt.Path,
738+
})
739+
encodedGenReq, err := proto.Marshal(genReq)
733740
if err != nil {
734-
return nil, fmt.Errorf("grpc-federation: failed to read plugin file: %s: %w", plugin.Path, err)
735-
}
736-
if plugin.Sha256 != "" {
737-
hash := sha256.Sum256(wasmFile)
738-
gotHash := hex.EncodeToString(hash[:])
739-
if plugin.Sha256 != gotHash {
740-
return nil, fmt.Errorf(
741-
`grpc-federation: expected plugin sha256 value is [%s] but got [%s]`,
742-
plugin.Sha256,
743-
gotHash,
744-
)
745-
}
741+
return nil, err
746742
}
747-
pluginRes, err := evalCodeGeneratorPlugin(ctx, wasmFile, bytes.NewBuffer(encodedGenReq))
743+
pluginRes, err := wp.Execute(ctx, bytes.NewBuffer(encodedGenReq))
748744
if err != nil {
749745
return nil, err
750746
}
751747
resp.File = append(resp.File, pluginRes.File...)
752748
}
753749
}
750+
754751
return &resp, nil
755752
}
756753

generator/wasm.go

Lines changed: 95 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,37 +3,63 @@ package generator
33
import (
44
"bytes"
55
"context"
6+
"crypto/sha256"
7+
"encoding/hex"
8+
"errors"
69
"fmt"
710
"io"
811
"os"
912
"path/filepath"
13+
"sync"
1014

1115
"github.com/tetratelabs/wazero"
1216
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
1317
"google.golang.org/protobuf/proto"
1418
"google.golang.org/protobuf/types/pluginpb"
1519
)
1620

17-
func evalCodeGeneratorPlugin(ctx context.Context, pluginFile []byte, req io.Reader) (*pluginpb.CodeGeneratorResponse, error) {
21+
// wasmPlugin holds a compiled WASM module and its runtime, allowing
22+
// the module to be instantiated multiple times without recompilation.
23+
// Execute is safe for concurrent use.
24+
type wasmPlugin struct {
25+
runtime wazero.Runtime
26+
compiled wazero.CompiledModule
27+
}
28+
29+
func newWasmPlugin(ctx context.Context, wasmBytes []byte) (*wasmPlugin, error) {
1830
runtimeCfg := wazero.NewRuntimeConfigInterpreter()
1931
if cache := getCompilationCache(); cache != nil {
2032
runtimeCfg = runtimeCfg.WithCompilationCache(cache)
2133
}
22-
2334
r := wazero.NewRuntimeWithConfig(ctx, runtimeCfg)
24-
2535
wasi_snapshot_preview1.MustInstantiate(ctx, r)
26-
buf := bytes.NewBuffer([]byte{})
2736

37+
compiled, err := r.CompileModule(ctx, wasmBytes)
38+
if err != nil {
39+
r.Close(ctx)
40+
return nil, fmt.Errorf("grpc-federation: failed to compile code-generator plugin: %w", err)
41+
}
42+
return &wasmPlugin{
43+
runtime: r,
44+
compiled: compiled,
45+
}, nil
46+
}
47+
48+
func (p *wasmPlugin) Execute(ctx context.Context, req io.Reader) (*pluginpb.CodeGeneratorResponse, error) {
49+
buf := new(bytes.Buffer)
2850
modCfg := wazero.NewModuleConfig().
2951
WithFSConfig(wazero.NewFSConfig().WithDirMount(".", "/")).
3052
WithStdin(req).
3153
WithStdout(buf).
3254
WithStderr(os.Stderr).
3355
WithArgs("wasi")
34-
if _, err := r.InstantiateWithConfig(ctx, pluginFile, modCfg); err != nil {
56+
57+
mod, err := p.runtime.InstantiateModule(ctx, p.compiled, modCfg)
58+
if err != nil {
3559
return nil, fmt.Errorf("grpc-federation: failed to instantiate code-generator plugin: %w", err)
3660
}
61+
mod.Close(ctx)
62+
3763
var res pluginpb.CodeGeneratorResponse
3864
resBytes := buf.Bytes()
3965
if len(resBytes) != 0 {
@@ -44,6 +70,70 @@ func evalCodeGeneratorPlugin(ctx context.Context, pluginFile []byte, req io.Read
4470
return &res, nil
4571
}
4672

73+
func (p *wasmPlugin) Close(ctx context.Context) error {
74+
return p.runtime.Close(ctx)
75+
}
76+
77+
// wasmPluginCache caches compiled WASM plugins so that the expensive
78+
// compilation step is performed only once per plugin path.
79+
type wasmPluginCache struct {
80+
mu sync.RWMutex
81+
plugins map[string]*wasmPlugin
82+
}
83+
84+
func newWasmPluginCache() *wasmPluginCache {
85+
return &wasmPluginCache{plugins: make(map[string]*wasmPlugin)}
86+
}
87+
88+
func (c *wasmPluginCache) getOrCreate(ctx context.Context, opt *WasmPluginOption) (*wasmPlugin, error) {
89+
wasmFile, err := os.ReadFile(opt.Path)
90+
if err != nil {
91+
return nil, fmt.Errorf("grpc-federation: failed to read plugin file: %s: %w", opt.Path, err)
92+
}
93+
hash := sha256.Sum256(wasmFile)
94+
gotHash := hex.EncodeToString(hash[:])
95+
if opt.Sha256 != "" && opt.Sha256 != gotHash {
96+
return nil, fmt.Errorf(
97+
`grpc-federation: expected plugin sha256 value is [%s] but got [%s]`,
98+
opt.Sha256,
99+
gotHash,
100+
)
101+
}
102+
cacheKey := opt.Path + ":" + gotHash
103+
104+
c.mu.RLock()
105+
if wp, ok := c.plugins[cacheKey]; ok {
106+
c.mu.RUnlock()
107+
return wp, nil
108+
}
109+
c.mu.RUnlock()
110+
111+
c.mu.Lock()
112+
defer c.mu.Unlock()
113+
114+
// Double-check after acquiring write lock.
115+
if wp, ok := c.plugins[cacheKey]; ok {
116+
return wp, nil
117+
}
118+
wp, err := newWasmPlugin(ctx, wasmFile)
119+
if err != nil {
120+
return nil, err
121+
}
122+
c.plugins[cacheKey] = wp
123+
return wp, nil
124+
}
125+
126+
func (c *wasmPluginCache) Close(ctx context.Context) error {
127+
c.mu.Lock()
128+
defer c.mu.Unlock()
129+
130+
errs := make([]error, 0, len(c.plugins))
131+
for _, wp := range c.plugins {
132+
errs = append(errs, wp.Close(ctx))
133+
}
134+
return errors.Join(errs...)
135+
}
136+
47137
func getCompilationCache() wazero.CompilationCache {
48138
tmpDir := os.TempDir()
49139
if tmpDir == "" {

0 commit comments

Comments
 (0)