Skip to content

Commit 1cba005

Browse files
feat(runtime): Optimize == for lists (#2247)
Co-authored-by: Oscar Spencer <oscar.spen@gmail.com>
1 parent 097ae7d commit 1cba005

File tree

3 files changed

+104
-76
lines changed

3 files changed

+104
-76
lines changed

compiler/test/stdlib/pervasives.test.gr

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,9 @@ record Comparable2 {
7575
}
7676
assert compare({ a: 1, b: true, c: void }, { a: 1, b: true, c: void }) == 0
7777
assert compare({ a: 1, b: true, c: void }, { a: 1, b: false, c: void }) > 0
78+
79+
// Large list equality, regression #2247
80+
let rec make_list = (n, acc) => {
81+
if (n == 0) acc else make_list(n - 1, [n, ...acc])
82+
}
83+
assert make_list(500_000, []) == make_list(500_000, [])

compiler/test/suites/basic_functionality.re

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,6 @@ describe("basic functionality", ({test, testSkip}) => {
377377
~config_fn=smallestFileConfig,
378378
"smallest_grain_program",
379379
"",
380-
6494,
380+
6503,
381381
);
382382
});

stdlib/runtime/equal.gr

Lines changed: 97 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,19 @@ module Equal
33

44
from "runtime/unsafe/memory" include Memory
55
from "runtime/unsafe/wasmi32" include WasmI32
6-
use WasmI32.{ (==), (!=), (&), (^), (+), (-), (*), (<), remS as (%), (<<) }
6+
use WasmI32.{
7+
(==),
8+
(!=),
9+
(&),
10+
(^),
11+
(+),
12+
(-),
13+
(*),
14+
(<),
15+
remS as (%),
16+
(<<),
17+
(>>),
18+
}
719
from "runtime/unsafe/wasmi64" include WasmI64
820
from "runtime/unsafe/wasmf32" include WasmF32
921
from "runtime/unsafe/tags" include Tags
@@ -14,6 +26,10 @@ primitive (!) = "@not"
1426
primitive (||) = "@or"
1527
primitive (&&) = "@and"
1628
primitive ignore = "@ignore"
29+
primitive builtinId = "@builtin.id"
30+
31+
@unsafe
32+
let _LIST_ID = WasmI32.fromGrain(builtinId("List"))
1733

1834
@unsafe
1935
let cycleMarker = 0x80000000n
@@ -23,38 +39,47 @@ let rec heapEqualHelp = (heapTag, xptr, yptr) => {
2339
match (heapTag) {
2440
t when t == Tags._GRAIN_ADT_HEAP_TAG => {
2541
// Check if the same constructor variant
26-
if (WasmI32.load(xptr, 12n) != WasmI32.load(yptr, 12n)) {
27-
false
42+
let mut xVariantTag = WasmI32.load(xptr, 12n)
43+
let mut yVariantTag = WasmI32.load(yptr, 12n)
44+
if (xVariantTag != yVariantTag) {
45+
return false
46+
}
47+
48+
// Handle lists separately to avoid stack overflow
49+
if (WasmI32.load(xptr, 8n) == _LIST_ID) {
50+
if (xVariantTag >> 1n == 1n) return true // End of list
51+
52+
if (!equalHelp(WasmI32.load(xptr, 20n), WasmI32.load(yptr, 20n))) {
53+
return false
54+
}
55+
56+
return equalHelp(WasmI32.load(xptr, 24n), WasmI32.load(yptr, 24n))
2857
} else {
2958
let xarity = WasmI32.load(xptr, 16n)
3059
let yarity = WasmI32.load(yptr, 16n)
3160

3261
// Cycle check
3362
if ((xarity & cycleMarker) == cycleMarker) {
34-
true
35-
} else {
36-
WasmI32.store(xptr, xarity ^ cycleMarker, 16n)
37-
WasmI32.store(yptr, yarity ^ cycleMarker, 16n)
38-
39-
let mut result = true
40-
41-
let bytes = xarity * 4n
42-
for (let mut i = 0n; i < bytes; i += 4n) {
43-
if (
44-
!equalHelp(
45-
WasmI32.load(xptr + i, 20n),
46-
WasmI32.load(yptr + i, 20n)
47-
)
48-
) {
49-
result = false
50-
break
51-
}
52-
}
53-
WasmI32.store(xptr, xarity, 16n)
54-
WasmI32.store(yptr, yarity, 16n)
63+
return true
64+
}
5565

56-
result
66+
WasmI32.store(xptr, xarity ^ cycleMarker, 16n)
67+
WasmI32.store(yptr, yarity ^ cycleMarker, 16n)
68+
69+
let bytes = xarity * 4n
70+
for (let mut i = 0n; i < bytes; i += 4n) {
71+
if (
72+
!equalHelp(WasmI32.load(xptr + i, 20n), WasmI32.load(yptr + i, 20n))
73+
) {
74+
WasmI32.store(xptr, xarity, 16n)
75+
WasmI32.store(yptr, yarity, 16n)
76+
return false
77+
}
5778
}
79+
WasmI32.store(xptr, xarity, 16n)
80+
WasmI32.store(yptr, yarity, 16n)
81+
82+
return true
5883
}
5984
},
6085
t when t == Tags._GRAIN_RECORD_HEAP_TAG => {
@@ -63,65 +88,64 @@ let rec heapEqualHelp = (heapTag, xptr, yptr) => {
6388

6489
// Cycle check
6590
if ((xlength & cycleMarker) == cycleMarker) {
66-
true
67-
} else {
68-
WasmI32.store(xptr, xlength ^ cycleMarker, 12n)
69-
WasmI32.store(yptr, ylength ^ cycleMarker, 12n)
70-
71-
let mut result = true
91+
return true
92+
}
7293

73-
let bytes = xlength * 4n
74-
for (let mut i = 0n; i < bytes; i += 4n) {
75-
if (
76-
!equalHelp(WasmI32.load(xptr + i, 16n), WasmI32.load(yptr + i, 16n))
77-
) {
78-
result = false
79-
break
80-
}
94+
WasmI32.store(xptr, xlength ^ cycleMarker, 12n)
95+
WasmI32.store(yptr, ylength ^ cycleMarker, 12n)
96+
97+
let bytes = xlength * 4n
98+
for (let mut i = 0n; i < bytes; i += 4n) {
99+
if (
100+
!equalHelp(WasmI32.load(xptr + i, 16n), WasmI32.load(yptr + i, 16n))
101+
) {
102+
WasmI32.store(xptr, xlength, 12n)
103+
WasmI32.store(yptr, ylength, 12n)
104+
return false
81105
}
82-
WasmI32.store(xptr, xlength, 12n)
83-
WasmI32.store(yptr, ylength, 12n)
84-
85-
result
86106
}
107+
WasmI32.store(xptr, xlength, 12n)
108+
WasmI32.store(yptr, ylength, 12n)
109+
110+
return true
87111
},
88112
t when t == Tags._GRAIN_ARRAY_HEAP_TAG => {
89113
let xlength = WasmI32.load(xptr, 4n)
90114
let ylength = WasmI32.load(yptr, 4n)
91115

92116
// Check if the same length
93117
if (xlength != ylength) {
94-
false
95-
} else if ((xlength & cycleMarker) == cycleMarker) {
96-
// Cycle check
97-
true
98-
} else {
99-
WasmI32.store(xptr, xlength ^ cycleMarker, 4n)
100-
WasmI32.store(yptr, ylength ^ cycleMarker, 4n)
118+
return false
119+
}
101120

102-
let mut result = true
103-
let bytes = xlength * 4n
104-
for (let mut i = 0n; i < bytes; i += 4n) {
105-
if (
106-
!equalHelp(WasmI32.load(xptr + i, 8n), WasmI32.load(yptr + i, 8n))
107-
) {
108-
result = false
109-
break
110-
}
111-
}
121+
// Cycle check
122+
if ((xlength & cycleMarker) == cycleMarker) {
123+
return true
124+
}
112125

113-
WasmI32.store(xptr, xlength, 4n)
114-
WasmI32.store(yptr, ylength, 4n)
126+
WasmI32.store(xptr, xlength ^ cycleMarker, 4n)
127+
WasmI32.store(yptr, ylength ^ cycleMarker, 4n)
115128

116-
result
129+
let bytes = xlength * 4n
130+
for (let mut i = 0n; i < bytes; i += 4n) {
131+
if (!equalHelp(WasmI32.load(xptr + i, 8n), WasmI32.load(yptr + i, 8n))) {
132+
WasmI32.store(xptr, xlength, 4n)
133+
WasmI32.store(yptr, ylength, 4n)
134+
return false
135+
}
117136
}
137+
138+
WasmI32.store(xptr, xlength, 4n)
139+
WasmI32.store(yptr, ylength, 4n)
140+
141+
return true
118142
},
119143
t when t == Tags._GRAIN_STRING_HEAP_TAG || t == Tags._GRAIN_BYTES_HEAP_TAG => {
120144
let xlength = WasmI32.load(xptr, 4n)
121145
let ylength = WasmI32.load(yptr, 4n)
122146

123147
// Check if the same length
124-
if (xlength != ylength) {
148+
return if (xlength != ylength) {
125149
false
126150
} else {
127151
Memory.compare(xptr + 8n, yptr + 8n, xlength) == 0n
@@ -132,44 +156,42 @@ let rec heapEqualHelp = (heapTag, xptr, yptr) => {
132156
let ysize = WasmI32.load(yptr, 4n)
133157

134158
if ((xsize & cycleMarker) == cycleMarker) {
135-
true
159+
return true
136160
} else {
137161
WasmI32.store(xptr, xsize ^ cycleMarker, 4n)
138162
WasmI32.store(yptr, ysize ^ cycleMarker, 4n)
139163

140-
let mut result = true
141164
let bytes = xsize * 4n
142165
for (let mut i = 0n; i < bytes; i += 4n) {
143166
if (
144167
!equalHelp(WasmI32.load(xptr + i, 8n), WasmI32.load(yptr + i, 8n))
145168
) {
146-
result = false
147-
break
169+
WasmI32.store(xptr, xsize, 4n)
170+
WasmI32.store(yptr, ysize, 4n)
171+
return false
148172
}
149173
}
150174

151175
WasmI32.store(xptr, xsize, 4n)
152176
WasmI32.store(yptr, ysize, 4n)
153177

154-
result
178+
return true
155179
}
156180
},
157181
t when t == Tags._GRAIN_UINT32_HEAP_TAG || t == Tags._GRAIN_INT32_HEAP_TAG => {
158182
let xval = WasmI32.load(xptr, 4n)
159183
let yval = WasmI32.load(yptr, 4n)
160-
xval == yval
184+
return xval == yval
161185
},
162186
// Float32 is handled by equalHelp directly
163187
t when t == Tags._GRAIN_UINT64_HEAP_TAG => {
164188
use WasmI64.{ (==) }
165189
let xval = WasmI64.load(xptr, 8n)
166190
let yval = WasmI64.load(yptr, 8n)
167-
xval == yval
168-
},
169-
_ => {
170-
// No other implementation
171-
xptr == yptr
191+
return xval == yval
172192
},
193+
// No other implementation
194+
_ => return xptr == yptr,
173195
}
174196
}
175197
and equalHelp = (x, y) => {

0 commit comments

Comments
 (0)