@@ -190,7 +190,7 @@ const decoderGenerators: Record<string, (jsTypeOverride?: 'number' | 'string') =
190
190
191
191
const defaultValueGenerators : Record < string , ( ) => string > = {
192
192
bool : ( ) => 'false' ,
193
- bytes : ( ) => 'new Uint8Array (0)' ,
193
+ bytes : ( ) => 'uint8ArrayAlloc (0)' ,
194
194
double : ( ) => '0' ,
195
195
fixed32 : ( ) => '0' ,
196
196
fixed64 : ( ) => '0n' ,
@@ -320,6 +320,10 @@ function createDefaultObject (fields: Record<string, FieldDef>, messageDef: Mess
320
320
defaultValueGenerator = defaultValueGeneratorsJsTypeOverrides [ jsTypeOverride ]
321
321
}
322
322
323
+ if ( type === 'bytes' ) {
324
+ moduleDef . addImport ( 'uint8arrays/alloc' , 'alloc' , 'uint8ArrayAlloc' )
325
+ }
326
+
323
327
defaultValue = defaultValueGenerator ( )
324
328
} else {
325
329
const def = findDef ( fieldDef . type , messageDef , moduleDef )
@@ -457,7 +461,7 @@ function defineFields (fields: Record<string, FieldDef>, messageDef: MessageDef,
457
461
458
462
function compileMessage ( messageDef : MessageDef , moduleDef : ModuleDef , flags ?: Flags ) : string {
459
463
if ( isEnumDef ( messageDef ) ) {
460
- moduleDef . imports . add ( 'enumeration' )
464
+ moduleDef . addImport ( 'protons-runtime' , 'enumeration' )
461
465
462
466
// check that the enum def values start from 0
463
467
if ( Object . values ( messageDef . values ) [ 0 ] !== 0 ) {
@@ -510,10 +514,11 @@ export namespace ${messageDef.name} {
510
514
const fields = messageDef . fields ?? { }
511
515
512
516
// import relevant modules
513
- moduleDef . imports . add ( 'encodeMessage' )
514
- moduleDef . imports . add ( 'decodeMessage' )
515
- moduleDef . imports . add ( 'message' )
516
- moduleDef . importedTypes . add ( 'Codec' )
517
+ moduleDef . addImport ( 'protons-runtime' , 'encodeMessage' )
518
+ moduleDef . addImport ( 'protons-runtime' , 'decodeMessage' )
519
+ moduleDef . addImport ( 'protons-runtime' , 'message' )
520
+ moduleDef . addTypeImport ( 'protons-runtime' , 'Codec' )
521
+ moduleDef . addTypeImport ( 'uint8arraylist' , 'Uint8ArrayList' )
517
522
518
523
const interfaceFields = defineFields ( fields , messageDef , moduleDef )
519
524
. join ( '\n ' )
@@ -544,10 +549,10 @@ export interface ${messageDef.name} {
544
549
545
550
if ( codec == null ) {
546
551
if ( fieldDef . enum ) {
547
- moduleDef . imports . add ( 'enumeration' )
552
+ moduleDef . addImport ( 'protons-runtime' , 'enumeration' )
548
553
type = 'enum'
549
554
} else {
550
- moduleDef . imports . add ( 'message' )
555
+ moduleDef . addImport ( 'protons-runtime' , 'message' )
551
556
type = 'message'
552
557
}
553
558
@@ -669,10 +674,10 @@ export interface ${messageDef.name} {
669
674
670
675
if ( codec == null ) {
671
676
if ( fieldDef . enum ) {
672
- moduleDef . imports . add ( 'enumeration' )
677
+ moduleDef . addImport ( 'protons-runtime' , 'enumeration' )
673
678
type = 'enum'
674
679
} else {
675
- moduleDef . imports . add ( 'message' )
680
+ moduleDef . addImport ( 'protons-runtime' , 'message' )
676
681
type = 'message'
677
682
}
678
683
@@ -689,7 +694,7 @@ export interface ${messageDef.name} {
689
694
let limit = ''
690
695
691
696
if ( fieldDef . lengthLimit != null ) {
692
- moduleDef . imports . add ( 'CodeError' )
697
+ moduleDef . addImport ( 'protons-runtime' , 'CodeError' )
693
698
694
699
limit = `
695
700
if (obj.${ fieldName } .size === ${ fieldDef . lengthLimit } ) {
@@ -707,7 +712,7 @@ export interface ${messageDef.name} {
707
712
let limit = ''
708
713
709
714
if ( fieldDef . lengthLimit != null ) {
710
- moduleDef . imports . add ( 'CodeError' )
715
+ moduleDef . addImport ( 'protons-runtime' , 'CodeError' )
711
716
712
717
limit = `
713
718
if (obj.${ fieldName } .length === ${ fieldDef . lengthLimit } ) {
@@ -785,23 +790,83 @@ export namespace ${messageDef.name} {
785
790
` . trimStart ( )
786
791
}
787
792
788
- interface ModuleDef {
789
- imports : Set < string >
790
- importedTypes : Set < string >
793
+ interface Import {
794
+ symbol : string
795
+ alias ?: string
796
+ type : boolean
797
+ }
798
+
799
+ class ModuleDef {
800
+ imports : Map < string , Import [ ] >
791
801
types : Set < string >
792
802
compiled : string [ ]
793
803
globals : Record < string , ClassDef >
794
- }
795
804
796
- function defineModule ( def : ClassDef , flags : Flags ) : ModuleDef {
797
- const moduleDef : ModuleDef = {
798
- imports : new Set ( ) ,
799
- importedTypes : new Set ( ) ,
800
- types : new Set ( ) ,
801
- compiled : [ ] ,
802
- globals : { }
805
+ constructor ( ) {
806
+ this . imports = new Map ( )
807
+ this . types = new Set ( )
808
+ this . compiled = [ ]
809
+ this . globals = { }
810
+ }
811
+
812
+ addImport ( module : string , symbol : string , alias ?: string ) : void {
813
+ const defs = this . _findDefs ( module )
814
+
815
+ for ( const def of defs ) {
816
+ // check if we already have a definition for this symbol
817
+ if ( def . symbol === symbol ) {
818
+ if ( alias !== def . alias ) {
819
+ throw new Error ( `Type symbol ${ symbol } imported from ${ module } with alias ${ def . alias } does not match alias ${ alias } ` )
820
+ }
821
+
822
+ // if it was a type before it's not now
823
+ def . type = false
824
+ return
825
+ }
826
+ }
827
+
828
+ defs . push ( {
829
+ symbol,
830
+ alias,
831
+ type : false
832
+ } )
833
+ }
834
+
835
+ addTypeImport ( module : string , symbol : string , alias ?: string ) : void {
836
+ const defs = this . _findDefs ( module )
837
+
838
+ for ( const def of defs ) {
839
+ // check if we already have a definition for this symbol
840
+ if ( def . symbol === symbol ) {
841
+ if ( alias !== def . alias ) {
842
+ throw new Error ( `Type symbol ${ symbol } imported from ${ module } with alias ${ def . alias } does not match alias ${ alias } ` )
843
+ }
844
+
845
+ return
846
+ }
847
+ }
848
+
849
+ defs . push ( {
850
+ symbol,
851
+ alias,
852
+ type : true
853
+ } )
803
854
}
804
855
856
+ _findDefs ( module : string ) : Import [ ] {
857
+ let defs = this . imports . get ( module )
858
+
859
+ if ( defs == null ) {
860
+ defs = [ ]
861
+ this . imports . set ( module , defs )
862
+ }
863
+
864
+ return defs
865
+ }
866
+ }
867
+
868
+ function defineModule ( def : ClassDef , flags : Flags ) : ModuleDef {
869
+ const moduleDef = new ModuleDef ( )
805
870
const defs = def . nested
806
871
807
872
if ( defs == null ) {
@@ -963,28 +1028,48 @@ export async function generate (source: string, flags: Flags): Promise<void> {
963
1028
]
964
1029
965
1030
const imports = [ ]
1031
+ const importedModules = Array . from ( [ ...moduleDef . imports . entries ( ) ] )
1032
+ . sort ( ( a , b ) => {
1033
+ return a [ 0 ] . localeCompare ( b [ 0 ] )
1034
+ } )
1035
+ . sort ( ( a , b ) => {
1036
+ const aAllTypes = a [ 1 ] . reduce ( ( acc , curr ) => {
1037
+ return acc && curr . type
1038
+ } , true )
966
1039
967
- if ( moduleDef . imports . size > 0 ) {
968
- imports . push ( `import { ${ Array . from ( moduleDef . imports ) . join ( ', ' ) } } from 'protons-runtime'` )
969
- }
1040
+ const bAllTypes = b [ 1 ] . reduce ( ( acc , curr ) => {
1041
+ return acc && curr . type
1042
+ } , true )
970
1043
971
- if ( moduleDef . imports . has ( 'encodeMessage' ) ) {
972
- imports . push ( "import type { Uint8ArrayList } from 'uint8arraylist'" )
973
- }
1044
+ if ( aAllTypes && ! bAllTypes ) {
1045
+ return 1
1046
+ }
1047
+
1048
+ if ( ! aAllTypes && bAllTypes ) {
1049
+ return - 1
1050
+ }
974
1051
975
- if ( moduleDef . importedTypes . size > 0 ) {
976
- imports . push ( `import type { ${ Array . from ( moduleDef . importedTypes ) . join ( ', ' ) } } from 'protons-runtime'` )
1052
+ return 0
1053
+ } )
1054
+
1055
+ for ( const imp of importedModules ) {
1056
+ const allTypes = imp [ 1 ] . reduce ( ( acc , curr ) => {
1057
+ return acc && curr . type
1058
+ } , true )
1059
+
1060
+ const symbols = imp [ 1 ] . sort ( ( a , b ) => {
1061
+ return a . symbol . localeCompare ( b . symbol )
1062
+ } ) . map ( imp => {
1063
+ return `${ ! allTypes && imp . type ? 'type ' : '' } ${ imp . symbol } ${ imp . alias != null ? ` as ${ imp . alias } ` : '' } `
1064
+ } ) . join ( ', ' )
1065
+
1066
+ imports . push ( `import ${ allTypes ? 'type ' : '' } { ${ symbols } } from '${ imp [ 0 ] } '` )
977
1067
}
978
1068
979
1069
const lines = [
980
1070
...ignores ,
981
1071
'' ,
982
- ...imports . sort ( ( a , b ) => {
983
- const aModule = a . split ( "from '" ) [ 1 ] . toString ( )
984
- const bModule = b . split ( "from '" ) [ 1 ] . toString ( )
985
-
986
- return aModule . localeCompare ( bModule )
987
- } ) ,
1072
+ ...imports ,
988
1073
'' ,
989
1074
...moduleDef . compiled
990
1075
]
0 commit comments