Skip to content

Commit 84a1084

Browse files
authored
Improve generic builtin type inference with multiple parameters (#904)
1 parent 7912525 commit 84a1084

File tree

7 files changed

+236
-174
lines changed

7 files changed

+236
-174
lines changed

src/ast.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1331,6 +1331,17 @@ export enum LiteralKind {
13311331
OBJECT
13321332
}
13331333

1334+
/** Checks if the given node represents a numeric (float or integer) literal. */
1335+
export function isNumericLiteral(node: Expression): bool {
1336+
if (node.kind == NodeKind.LITERAL) {
1337+
switch ((<LiteralExpression>node).literalKind) {
1338+
case LiteralKind.FLOAT:
1339+
case LiteralKind.INTEGER: return true;
1340+
}
1341+
}
1342+
return false;
1343+
}
1344+
13341345
/** Base class of all literal expressions. */
13351346
export abstract class LiteralExpression extends Expression {
13361347
kind = NodeKind.LITERAL;

src/builtins.ts

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ import {
2121
LiteralKind,
2222
LiteralExpression,
2323
StringLiteralExpression,
24-
CallExpression
24+
CallExpression,
25+
isNumericLiteral
2526
} from "./ast";
2627

2728
import {
@@ -1081,7 +1082,7 @@ export function compileCall(
10811082
) return module.unreachable();
10821083
let arg0 = typeArguments
10831084
? compiler.compileExpression(operands[0], typeArguments[0], Constraints.CONV_IMPLICIT | Constraints.MUST_WRAP)
1084-
: compiler.compileExpression(operands[0], Type.f64, Constraints.MUST_WRAP);
1085+
: compiler.compileExpression(operands[0], Type.auto, Constraints.MUST_WRAP);
10851086
let type = compiler.currentType;
10861087
if (!type.is(TypeFlags.REFERENCE)) {
10871088
switch (type.kind) {
@@ -1179,12 +1180,21 @@ export function compileCall(
11791180
checkTypeOptional(typeArguments, reportNode, compiler, true) |
11801181
checkArgsRequired(operands, 2, reportNode, compiler)
11811182
) return module.unreachable();
1183+
let left = operands[0];
11821184
let arg0 = typeArguments
1183-
? compiler.compileExpression(operands[0], typeArguments[0], Constraints.CONV_IMPLICIT | Constraints.MUST_WRAP)
1184-
: compiler.compileExpression(operands[0], Type.f64, Constraints.MUST_WRAP);
1185+
? compiler.compileExpression(left, typeArguments[0], Constraints.CONV_IMPLICIT | Constraints.MUST_WRAP)
1186+
: compiler.compileExpression(operands[0], Type.auto, Constraints.MUST_WRAP);
11851187
let type = compiler.currentType;
11861188
if (!type.is(TypeFlags.REFERENCE)) {
1187-
let arg1 = compiler.compileExpression(operands[1], type, Constraints.CONV_IMPLICIT | Constraints.MUST_WRAP);
1189+
let arg1: ExpressionRef;
1190+
if (!typeArguments && isNumericLiteral(left)) { // prefer right type
1191+
arg1 = compiler.compileExpression(operands[1], type, Constraints.MUST_WRAP);
1192+
if (compiler.currentType != type) {
1193+
arg0 = compiler.compileExpression(left, type = compiler.currentType, Constraints.CONV_IMPLICIT | Constraints.MUST_WRAP);
1194+
}
1195+
} else {
1196+
arg1 = compiler.compileExpression(operands[1], type, Constraints.CONV_IMPLICIT | Constraints.MUST_WRAP);
1197+
}
11881198
let op: BinaryOp = -1;
11891199
switch (type.kind) {
11901200
case TypeKind.I8:
@@ -1240,12 +1250,21 @@ export function compileCall(
12401250
checkTypeOptional(typeArguments, reportNode, compiler, true) |
12411251
checkArgsRequired(operands, 2, reportNode, compiler)
12421252
) return module.unreachable();
1253+
let left = operands[0];
12431254
let arg0 = typeArguments
1244-
? compiler.compileExpression(operands[0], typeArguments[0], Constraints.CONV_IMPLICIT | Constraints.MUST_WRAP)
1245-
: compiler.compileExpression(operands[0], Type.f64, Constraints.MUST_WRAP);
1255+
? compiler.compileExpression(left, typeArguments[0], Constraints.CONV_IMPLICIT | Constraints.MUST_WRAP)
1256+
: compiler.compileExpression(operands[0], Type.auto, Constraints.MUST_WRAP);
12461257
let type = compiler.currentType;
12471258
if (!type.is(TypeFlags.REFERENCE)) {
1248-
let arg1 = compiler.compileExpression(operands[1], type, Constraints.CONV_IMPLICIT | Constraints.MUST_WRAP);
1259+
let arg1: ExpressionRef;
1260+
if (!typeArguments && isNumericLiteral(left)) { // prefer right type
1261+
arg1 = compiler.compileExpression(operands[1], type, Constraints.MUST_WRAP);
1262+
if (compiler.currentType != type) {
1263+
arg0 = compiler.compileExpression(left, type = compiler.currentType, Constraints.CONV_IMPLICIT | Constraints.MUST_WRAP);
1264+
}
1265+
} else {
1266+
arg1 = compiler.compileExpression(operands[1], type, Constraints.CONV_IMPLICIT | Constraints.MUST_WRAP);
1267+
}
12491268
let op: BinaryOp = -1;
12501269
switch (type.kind) {
12511270
case TypeKind.I8:
@@ -1303,7 +1322,7 @@ export function compileCall(
13031322
) return module.unreachable();
13041323
let arg0 = typeArguments
13051324
? compiler.compileExpression(operands[0], typeArguments[0], Constraints.CONV_IMPLICIT)
1306-
: compiler.compileExpression(operands[0], Type.f64, Constraints.NONE);
1325+
: compiler.compileExpression(operands[0], Type.auto, Constraints.NONE);
13071326
let type = compiler.currentType;
13081327
if (!type.is(TypeFlags.REFERENCE)) {
13091328
switch (type.kind) {
@@ -1335,7 +1354,7 @@ export function compileCall(
13351354
) return module.unreachable();
13361355
let arg0 = typeArguments
13371356
? compiler.compileExpression(operands[0], typeArguments[0], Constraints.CONV_IMPLICIT)
1338-
: compiler.compileExpression(operands[0], Type.f64, Constraints.NONE);
1357+
: compiler.compileExpression(operands[0], Type.auto, Constraints.NONE);
13391358
let type = compiler.currentType;
13401359
if (!type.is(TypeFlags.REFERENCE)) {
13411360
switch (type.kind) {
@@ -1390,7 +1409,7 @@ export function compileCall(
13901409
) return module.unreachable();
13911410
let arg0 = typeArguments
13921411
? compiler.compileExpression(operands[0], typeArguments[0], Constraints.CONV_IMPLICIT)
1393-
: compiler.compileExpression(operands[0], Type.f64, Constraints.NONE);
1412+
: compiler.compileExpression(operands[0], Type.auto, Constraints.NONE);
13941413
let type = compiler.currentType;
13951414
if (!type.is(TypeFlags.REFERENCE)) {
13961415
switch (type.kind) {
@@ -1498,7 +1517,7 @@ export function compileCall(
14981517
) return module.unreachable();
14991518
let arg0 = typeArguments
15001519
? compiler.compileExpression(operands[0], typeArguments[0], Constraints.CONV_IMPLICIT)
1501-
: compiler.compileExpression(operands[0], Type.f64, Constraints.NONE);
1520+
: compiler.compileExpression(operands[0], Type.auto, Constraints.NONE);
15021521
let type = compiler.currentType;
15031522
if (!type.is(TypeFlags.REFERENCE)) {
15041523
switch (type.kind) {

tests/compiler/builtins.optimized.wat

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,12 @@
428428
f64.const 1.25
429429
call $~lib/number/isFinite<f64>
430430
global.set $builtins/b
431+
f64.const 0
432+
global.set $builtins/F
433+
f32.const 0
434+
global.get $builtins/f
435+
f32.max
436+
global.set $builtins/f
431437
i32.const 8
432438
i32.load
433439
global.set $builtins/i
@@ -627,7 +633,7 @@
627633
if
628634
i32.const 0
629635
i32.const 64
630-
i32.const 294
636+
i32.const 299
631637
i32.const 0
632638
call $~lib/builtins/abort
633639
unreachable
@@ -638,7 +644,7 @@
638644
if
639645
i32.const 0
640646
i32.const 64
641-
i32.const 295
647+
i32.const 300
642648
i32.const 0
643649
call $~lib/builtins/abort
644650
unreachable
@@ -648,7 +654,7 @@
648654
if
649655
i32.const 0
650656
i32.const 64
651-
i32.const 296
657+
i32.const 301
652658
i32.const 0
653659
call $~lib/builtins/abort
654660
unreachable
@@ -658,7 +664,7 @@
658664
if
659665
i32.const 0
660666
i32.const 64
661-
i32.const 297
667+
i32.const 302
662668
i32.const 0
663669
call $~lib/builtins/abort
664670
unreachable
@@ -668,7 +674,7 @@
668674
if
669675
i32.const 0
670676
i32.const 64
671-
i32.const 298
677+
i32.const 303
672678
i32.const 0
673679
call $~lib/builtins/abort
674680
unreachable
@@ -678,7 +684,7 @@
678684
if
679685
i32.const 0
680686
i32.const 64
681-
i32.const 299
687+
i32.const 304
682688
i32.const 0
683689
call $~lib/builtins/abort
684690
unreachable
@@ -689,7 +695,7 @@
689695
if
690696
i32.const 0
691697
i32.const 64
692-
i32.const 300
698+
i32.const 305
693699
i32.const 0
694700
call $~lib/builtins/abort
695701
unreachable
@@ -700,7 +706,7 @@
700706
if
701707
i32.const 0
702708
i32.const 64
703-
i32.const 301
709+
i32.const 306
704710
i32.const 0
705711
call $~lib/builtins/abort
706712
unreachable
@@ -789,7 +795,7 @@
789795
if
790796
i32.const 0
791797
i32.const 64
792-
i32.const 437
798+
i32.const 442
793799
i32.const 2
794800
call $~lib/builtins/abort
795801
unreachable
@@ -801,7 +807,7 @@
801807
if
802808
i32.const 0
803809
i32.const 64
804-
i32.const 438
810+
i32.const 443
805811
i32.const 2
806812
call $~lib/builtins/abort
807813
unreachable
@@ -813,7 +819,7 @@
813819
if
814820
i32.const 0
815821
i32.const 64
816-
i32.const 439
822+
i32.const 444
817823
i32.const 2
818824
call $~lib/builtins/abort
819825
unreachable
@@ -825,7 +831,7 @@
825831
if
826832
i32.const 0
827833
i32.const 64
828-
i32.const 440
834+
i32.const 445
829835
i32.const 2
830836
call $~lib/builtins/abort
831837
unreachable
@@ -837,7 +843,7 @@
837843
if
838844
i32.const 0
839845
i32.const 64
840-
i32.const 441
846+
i32.const 446
841847
i32.const 2
842848
call $~lib/builtins/abort
843849
unreachable
@@ -849,7 +855,7 @@
849855
if
850856
i32.const 0
851857
i32.const 64
852-
i32.const 442
858+
i32.const 447
853859
i32.const 2
854860
call $~lib/builtins/abort
855861
unreachable
@@ -861,7 +867,7 @@
861867
if
862868
i32.const 0
863869
i32.const 64
864-
i32.const 443
870+
i32.const 448
865871
i32.const 2
866872
call $~lib/builtins/abort
867873
unreachable
@@ -873,7 +879,7 @@
873879
if
874880
i32.const 0
875881
i32.const 64
876-
i32.const 444
882+
i32.const 449
877883
i32.const 2
878884
call $~lib/builtins/abort
879885
unreachable
@@ -885,7 +891,7 @@
885891
if
886892
i32.const 0
887893
i32.const 64
888-
i32.const 445
894+
i32.const 450
889895
i32.const 2
890896
call $~lib/builtins/abort
891897
unreachable
@@ -897,7 +903,7 @@
897903
if
898904
i32.const 0
899905
i32.const 64
900-
i32.const 446
906+
i32.const 451
901907
i32.const 2
902908
call $~lib/builtins/abort
903909
unreachable
@@ -909,7 +915,7 @@
909915
if
910916
i32.const 0
911917
i32.const 64
912-
i32.const 447
918+
i32.const 452
913919
i32.const 2
914920
call $~lib/builtins/abort
915921
unreachable
@@ -921,7 +927,7 @@
921927
if
922928
i32.const 0
923929
i32.const 64
924-
i32.const 448
930+
i32.const 453
925931
i32.const 2
926932
call $~lib/builtins/abort
927933
unreachable
@@ -933,7 +939,7 @@
933939
if
934940
i32.const 0
935941
i32.const 64
936-
i32.const 449
942+
i32.const 454
937943
i32.const 2
938944
call $~lib/builtins/abort
939945
unreachable
@@ -945,7 +951,7 @@
945951
if
946952
i32.const 0
947953
i32.const 64
948-
i32.const 450
954+
i32.const 455
949955
i32.const 2
950956
call $~lib/builtins/abort
951957
unreachable
@@ -957,7 +963,7 @@
957963
if
958964
i32.const 0
959965
i32.const 64
960-
i32.const 451
966+
i32.const 456
961967
i32.const 2
962968
call $~lib/builtins/abort
963969
unreachable
@@ -969,7 +975,7 @@
969975
if
970976
i32.const 0
971977
i32.const 64
972-
i32.const 452
978+
i32.const 457
973979
i32.const 2
974980
call $~lib/builtins/abort
975981
unreachable
@@ -981,7 +987,7 @@
981987
if
982988
i32.const 0
983989
i32.const 64
984-
i32.const 453
990+
i32.const 458
985991
i32.const 2
986992
call $~lib/builtins/abort
987993
unreachable
@@ -993,7 +999,7 @@
993999
if
9941000
i32.const 0
9951001
i32.const 64
996-
i32.const 454
1002+
i32.const 459
9971003
i32.const 2
9981004
call $~lib/builtins/abort
9991005
unreachable
@@ -1005,7 +1011,7 @@
10051011
if
10061012
i32.const 0
10071013
i32.const 64
1008-
i32.const 455
1014+
i32.const 460
10091015
i32.const 2
10101016
call $~lib/builtins/abort
10111017
unreachable
@@ -1017,7 +1023,7 @@
10171023
if
10181024
i32.const 0
10191025
i32.const 64
1020-
i32.const 456
1026+
i32.const 461
10211027
i32.const 2
10221028
call $~lib/builtins/abort
10231029
unreachable

tests/compiler/builtins.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,11 @@ F = trunc<f64>(1.25);
158158
b = isNaN<f64>(1.25);
159159
b = isFinite<f64>(1.25);
160160

161+
// prefer right type if left is a numeric literal
162+
163+
F = min(0, 1.0);
164+
f = max(0, f);
165+
161166
// load and store
162167

163168
i = load<i32>(8); store<i32>(8, i);

0 commit comments

Comments
 (0)