@@ -12,22 +12,48 @@ import { basename, normalize as normalizePath } from "path";
1212
1313import { findPathsToCustomField , getCustomType } from "../src/encoding/customTypes/utils.ts" ;
1414
15- runNodeJs (
16- createEcmaScriptPlugin ( {
17- name : "protoc-gen-customtype-patches" ,
18- version : "v1" ,
19- generateTs,
20- } ) ,
21- ) ;
15+ export interface PluginOptions {
16+ /**
17+ * if true, we will patch the whole tree of the message type, starting from the custom field type and up to the root
18+ * in case of patching ts-proto generated types which has self-references, we need to patch only leaf level
19+ * @default false
20+ */
21+ patchWholeTree: boolean ;
22+ }
23+
24+ runNodeJs ( createEcmaScriptPlugin < PluginOptions > ( { name : "protoc-gen-customtype-patches" , version : "v1" , parseOptions, generateTs } ) ) ;
2225
2326const PROTO_PATH = "../protos" ;
24- function generateTs ( schema : Schema ) : void {
27+
28+ function parseOptions ( rawOptions : Array < {
29+ key : string ;
30+ value : string ;
31+ } > ) : PluginOptions {
32+ const options : PluginOptions = {
33+ patchWholeTree : false ,
34+ } ;
35+
36+ for ( const { key, value } of rawOptions ) {
37+ if ( key === "patch_whole_tree" ) {
38+ options . patchWholeTree = value === "true" ;
39+ }
40+ }
41+
42+ return options ;
43+ }
44+
45+ function generateTs ( schema : Schema < PluginOptions > ) : void {
2546 const allPaths : DescField [ ] [ ] = [ ] ;
2647
2748 schema . files . forEach ( ( file ) => {
2849 file . messages . forEach ( ( message ) => {
2950 const paths = findPathsToCustomField ( message , ( ) => true ) ;
30- allPaths . push ( ...paths ) ;
51+ if ( schema . options . patchWholeTree ) {
52+ allPaths . push ( ...paths ) ;
53+ } else {
54+ const leaves = paths . map ( ( path ) => path . slice ( - 1 ) ) ;
55+ allPaths . push ( ...leaves ) ;
56+ }
3157 } ) ;
3258 } ) ;
3359 if ( ! allPaths . length ) {
@@ -100,6 +126,24 @@ function generateTs(schema: Schema): void {
100126 patchesFile . print ( `const p = {\n${ indent ( patches . join ( ",\n" ) ) } \n};\n` ) ;
101127 patchesFile . print ( `export const patches = p;` ) ;
102128
129+ const patchesTypeFileName = fileName . replace ( "CustomTypePatches" , "PatchMessage" ) ;
130+ const patchTypeFile = schema . generateFile ( patchesTypeFileName ) ;
131+ patchTypeFile . print ( `import { patches } from "./${ fileName } ";` ) ;
132+ patchTypeFile . print ( `import type { MessageDesc } from "../../sdk/client/types.ts";` ) ;
133+ patchTypeFile . print ( `export const patched = <T extends MessageDesc>(messageDesc: T): T => {` ) ;
134+ patchTypeFile . print ( ` const patchMessage = patches[messageDesc.$type as keyof typeof patches] as any;` ) ;
135+ patchTypeFile . print ( ` if (!patchMessage) return messageDesc;` ) ;
136+ patchTypeFile . print ( ` return {` ) ;
137+ patchTypeFile . print ( ` ...messageDesc,` ) ;
138+ patchTypeFile . print ( ` encode(message, writer) {` ) ;
139+ patchTypeFile . print ( ` return messageDesc.encode(patchMessage(message, 'encode'), writer);` ) ;
140+ patchTypeFile . print ( ` },` ) ;
141+ patchTypeFile . print ( ` decode(input, length) {` ) ;
142+ patchTypeFile . print ( ` return patchMessage(messageDesc.decode(input, length), 'decode');` ) ;
143+ patchTypeFile . print ( ` },` ) ;
144+ patchTypeFile . print ( ` };` ) ;
145+ patchTypeFile . print ( `};` ) ;
146+
103147 const testsFile = schema . generateFile ( fileName . replace ( / \. t s $ / , ".spec.ts" ) ) ;
104148 generateTests ( basename ( fileName ) , testsFile , messageToCustomFields ) ;
105149}
@@ -185,13 +229,13 @@ function generateTests(fileName: string, testsFile: GeneratedFile, messageToCust
185229 testsFile. print ( `import { expect, describe, it } from "@jest/globals";` ) ;
186230 testsFile . print ( `import { patches } from "./${ basename ( fileName ) } ";` ) ;
187231 testsFile . print ( `import { generateMessage, type MessageSchema } from "@test/helpers/generateMessage";` ) ;
188- testsFile.print(` import type { TypePatches } from "../../sdk/client/applyPatches .ts" ; `);
232+ testsFile . print ( `import type { TypePatches } from "../../sdk/client/types .ts";` ) ;
189233 testsFile . print ( "" ) ;
190234 testsFile . print ( `const messageTypes: Record<string, MessageSchema> = {` ) ;
191235 for ( const [ message , fields ] of messageToCustomFields . entries ( ) ) {
192236 testsFile . print ( ` "${ message . typeName } ": {` ) ;
193237 testsFile . print ( ` type: ` , testsFile . import ( message . name , `${ PROTO_PATH } /${ message . file . name } .ts` ) , `,` ) ;
194- testsFile . print ( ` fields: [` , ...Array . from ( fields , f => serializeField ( f , testsFile ) ) , `],` ) ;
238+ testsFile . print ( ` fields: [` , ...Array . from ( fields , ( f ) => serializeField ( f , testsFile ) ) , `],` ) ;
195239 testsFile . print ( ` },` ) ;
196240 }
197241 testsFile . print ( `};` ) ;
@@ -236,7 +280,7 @@ function serializeField(f: DescField, file: GeneratedFile): Printable {
236280 field . push ( `scalarType: ${ f . scalar } ,` ) ;
237281 }
238282 if ( f . fieldKind === "enum" ) {
239- field . push ( `enum: ` , JSON . stringify ( f . enum . values . map ( v => v . localName ) ) , `,` ) ;
283+ field . push ( `enum: ` , JSON . stringify ( f . enum . values . map ( ( v ) => v . localName ) ) , `,` ) ;
240284 }
241285 if ( getCustomType ( f ) ) {
242286 field . push ( `customType: "${ getCustomType ( f ) ! . shortName } ",` ) ;
@@ -246,10 +290,10 @@ function serializeField(f: DescField, file: GeneratedFile): Printable {
246290 }
247291 if ( f . message ) {
248292 field . push ( `message: {fields: [` ,
249- ...f . message . fields . map ( nf => serializeField ( nf , file ) ) ,
293+ ...f . message . fields . map ( ( nf ) => serializeField ( nf , file ) ) ,
250294 `],` ,
251295 `type: ` , file . import ( f . message . name , `${ PROTO_PATH } /${ f . message . file . name } .ts` ) ,
252- `},`
296+ `},` ,
253297 ) ;
254298 }
255299 field . push ( `},` ) ;
0 commit comments