Skip to content

Commit 3a71933

Browse files
authored
Fix generic equals on Chez (#429) (#970)
Resolves #429.
1 parent 45de8c4 commit 3a71933

File tree

5 files changed

+134
-8
lines changed

5 files changed

+134
-8
lines changed

effekt/jvm/src/test/scala/effekt/ChezSchemeTests.scala

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,6 @@ abstract class ChezSchemeTests extends EffektTests {
6060

6161
examplesDir / "pos" / "io", // async io is only implemented for monadic JS
6262

63-
64-
examplesDir / "pos" / "issue429.effekt",
65-
6663
// Generic comparison
6764
examplesDir / "pos" / "genericcompare.effekt",
6865
examplesDir / "pos" / "issue733.effekt",

examples/pos/issue429.check

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,29 @@
1+
=== Basic equality tests ===
12
true
23
true
34
true
45
true
5-
true
6+
true
7+
8+
=== Record equality tests ===
9+
true
10+
false
11+
true
12+
13+
=== Nested structure tests ===
14+
true
15+
false
16+
true
17+
false
18+
19+
=== Tree equality tests ===
20+
true
21+
false
22+
false
23+
true
24+
25+
=== Complex nested structure tests ===
26+
true
27+
false
28+
true
29+
false

examples/pos/issue429.effekt

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,89 @@
1+
// Test generic equality implementation
12
type MyType {
23
MySingleCase()
34
}
45

56
record EmptyRecord()
67

8+
record Person(name: String, age: Int)
9+
10+
// Binary tree with values in nodes
11+
type Tree[T] {
12+
Leaf();
13+
Node(value: T, left: Tree[T], right: Tree[T])
14+
}
15+
16+
// Create a simple tree for testing
17+
def createTree[T](depth: Int, value: T): Tree[T] = {
18+
if (depth <= 0) {
19+
Leaf()
20+
} else {
21+
Node(value, createTree(depth - 1, value), createTree(depth - 1, value))
22+
}
23+
}
24+
25+
// Create a more complex tree with different values
26+
def createMixedTree(): Tree[Int] = {
27+
Node(1,
28+
Node(2, Leaf(), Leaf()),
29+
Node(3,
30+
Node(4, Leaf(), Leaf()),
31+
Leaf()
32+
)
33+
)
34+
}
35+
736
def main() = {
37+
// Original tests
38+
println("=== Basic equality tests ===")
839
println(MySingleCase().equals(MySingleCase())) // ~> true
940
println(EmptyRecord().equals(EmptyRecord())) // ~> true
10-
1141
println(Some(MySingleCase()).equals(Some(MySingleCase()))) // ~> true
1242
println(Some(EmptyRecord()).equals(Some(EmptyRecord()))) // ~> true
13-
1443
println([Some(EmptyRecord()), Some(EmptyRecord()), None()].equals([Some(EmptyRecord()), Some(EmptyRecord()), None()]))
44+
45+
// Enhanced tests with records and nested structures
46+
println("\n=== Record equality tests ===")
47+
val person1 = Person("Alice", 30)
48+
val person2 = Person("Alice", 30)
49+
val person3 = Person("Bob", 25)
50+
println(person1.equals(person2)) // ~> true
51+
println(person1.equals(person3)) // ~> false
52+
println(Person("Alice", 30).equals(Person("Alice", 30))) // ~> true
53+
54+
// Test nested records
55+
println("\n=== Nested structure tests ===")
56+
println(Some(person1).equals(Some(person2))) // ~> true
57+
println(Some(person1).equals(Some(person3))) // ~> false
58+
println([person1, person2].equals([person1, person2])) // ~> true
59+
println([person1, person3].equals([person1, person2])) // ~> false
60+
61+
// Tree equality tests
62+
println("\n=== Tree equality tests ===")
63+
val tree1 = createTree(3, 42)
64+
val tree2 = createTree(3, 42)
65+
val tree3 = createTree(2, 42)
66+
val tree4 = createTree(3, 100)
67+
68+
println(tree1.equals(tree2)) // ~> true
69+
println(tree1.equals(tree3)) // ~> false
70+
println(tree1.equals(tree4)) // ~> false
71+
72+
val mixedTree1 = createMixedTree()
73+
val mixedTree2 = createMixedTree()
74+
println(mixedTree1.equals(mixedTree2)) // ~> true
75+
76+
// Even more complex nested structures
77+
println("\n=== Complex nested structure tests ===")
78+
val nestedList1 = [Some(tree1), None(), Some(mixedTree1)]
79+
val nestedList2 = [Some(tree2), None(), Some(mixedTree2)]
80+
val nestedList3 = [Some(tree3), None(), Some(mixedTree1)]
81+
82+
println(nestedList1.equals(nestedList2)) // ~> true
83+
println(nestedList1.equals(nestedList3)) // ~> false
84+
85+
val complexRecord = Person("Charlie", 35)
86+
val recordWithTree = [complexRecord, Person("Dave", 40), Person("Charlie", 35)]
87+
println(recordWithTree.equals([complexRecord, Person("Dave", 40), Person("Charlie", 35)])) // ~> true
88+
println(recordWithTree.equals([complexRecord, Person("Dave", 40), Person("Charlie", 36)])) // ~> false
1589
}

libraries/chez/common/effekt_primitives.ss

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,39 @@
6464
(display str)
6565
(newline))
6666

67+
; Custom structural equality that properly handles records
6768
(define (equal_impl obj1 obj2)
68-
(equal? obj1 obj2))
69+
(cond
70+
; Same object reference (fast path)
71+
[(eq? obj1 obj2) #t]
72+
73+
; If both are records, compare them structurally
74+
[(and (record? obj1) (record? obj2))
75+
(let* ([rtd1 (record-rtd obj1)]
76+
[rtd2 (record-rtd obj2)])
77+
; Check if same record type
78+
(if (eq? rtd1 rtd2)
79+
(let* ([n (vector-length (record-type-field-names rtd1))]
80+
[result #t])
81+
; Compare all fields recursively
82+
(do ([i 0 (+ i 1)])
83+
((or (= i n) (not result)) result)
84+
(let ([field1 ((record-accessor rtd1 i) obj1)]
85+
[field2 ((record-accessor rtd2 i) obj2)])
86+
(if (not (equal_impl field1 field2))
87+
(set! result #f)))))
88+
#f))]
89+
90+
; For lists, compare elements recursively
91+
[(and (list? obj1) (list? obj2))
92+
(and (= (length obj1) (length obj2))
93+
(let loop ([l1 obj1] [l2 obj2])
94+
(or (null? l1)
95+
(and (equal_impl (car l1) (car l2))
96+
(loop (cdr l1) (cdr l2))))))]
97+
98+
; For all other types, use Scheme's built-in equal?
99+
[else (equal? obj1 obj2)]))
69100

70101
(define-syntax thunk
71102
(syntax-rules ()

libraries/common/effekt.effekt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def println(o: Ordering): Unit = println(o.show)
203203
/// Structural equality: Not available in the LLVM backend
204204
extern pure def equals[R](x: R, y: R): Bool =
205205
js "$effekt.equals(${x}, ${y})"
206-
chez "(equal? ${x} ${y})"
206+
chez "(equal_impl ${x} ${y})"
207207

208208
def differsFrom[R](x: R, y: R): Bool =
209209
not(equals(x, y))

0 commit comments

Comments
 (0)