Skip to content

Commit 12c24f3

Browse files
author
Oron Port
committed
feat: Implement explicit conditional expression assignment transformation
This commit introduces a new compiler stage `ExplicitCondExprAssign` that transforms conditional expressions into explicit assignments. Key changes include: - Adding a new stage to convert anonymous conditional expressions to direct assignments - Extending `DFRef` with an `isTypeRef` method to filter type references - Updating test cases to demonstrate the new transformation - Modifying reference Verilog and Scala examples to reflect the new assignment style
1 parent c304ff9 commit 12c24f3

File tree

19 files changed

+354
-124
lines changed

19 files changed

+354
-124
lines changed

compiler/ir/src/main/scala/dfhdl/compiler/ir/DFRef.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ object DFRef:
4242
val originRefType = classTag[DFVal.CanBeExpr]
4343
override def copyAsNewRef: this.type = new TypeRef {}.asInstanceOf[this.type]
4444

45+
extension (ref: DFRefAny)
46+
def isTypeRef: Boolean = ref match
47+
case ref: TypeRef => true
48+
case _ => false
4549
def unapply[M <: DFMember](ref: DFRef[M])(using MemberGetSet): Option[M] = Some(ref.get)
4650
end DFRef
4751

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
package dfhdl.compiler.stages
2+
3+
import dfhdl.compiler.analysis.*
4+
import dfhdl.compiler.ir.*
5+
import dfhdl.compiler.patching.*
6+
import dfhdl.internals.*
7+
import dfhdl.options.CompilerOptions
8+
import scala.reflect.classTag
9+
10+
/* This stage transforms an assignment from a conditional expression to a statement.*/
11+
case object ExplicitCondExprAssign extends Stage:
12+
def dependencies: List[Stage] = List()
13+
def nullifies: Set[Stage] = Set(DropUnreferencedAnons)
14+
15+
object IgnoreTypeRefs extends Patch.Replace.RefFilter:
16+
def apply(refs: Set[DFRefAny])(using MemberGetSet): Set[DFRefAny] =
17+
refs.filterNot(_.isTypeRef)
18+
19+
def transform(designDB: DB)(using MemberGetSet, CompilerOptions): DB =
20+
var headers = List.empty[DFConditional.Header]
21+
extension (ch: DFConditional.Header)
22+
// recursive call to patch conditional block chains
23+
private def patchChains(headerVar: DFVal, op: DFNet.Op): List[(DFMember, Patch)] =
24+
// changing type of header to unit, since the expression is now a statement
25+
headers = ch :: headers
26+
val cbChain = getSet.designDB.conditionalChainTable(ch)
27+
val lastMembers = cbChain.map(_.members(MemberView.Folded).last)
28+
lastMembers.flatMap {
29+
case ident @ Ident(underlying: DFConditional.Header) =>
30+
ident -> Patch.Remove() :: underlying.patchChains(headerVar, op)
31+
case ident @ Ident(underlying) =>
32+
val assignDsn = new MetaDesign(
33+
ident,
34+
Patch.Add.Config.After
35+
):
36+
(op: @unchecked) match
37+
case DFNet.Op.Assignment =>
38+
headerVar.asVarAny.:=(underlying.asValAny)(using
39+
dfc.setMetaAnon(ident.meta.position)
40+
)
41+
case DFNet.Op.NBAssignment =>
42+
import dfhdl.core.nbassign
43+
headerVar.asVarAny.nbassign(underlying.asValAny)(using
44+
dfc.setMetaAnon(ident.meta.position)
45+
)
46+
ident -> Patch.Remove() :: assignDsn.patch :: Nil
47+
case _ => ??? // not possible
48+
}
49+
end patchChains
50+
private def patchChainsNet(
51+
headerVar: DFVal,
52+
net: DFNet,
53+
op: DFNet.Op
54+
): List[(DFMember, Patch)] =
55+
val removeNetPatch = net -> Patch.Remove()
56+
removeNetPatch :: ch.patchChains(headerVar, op)
57+
end extension
58+
val patchList1 = designDB.members.view
59+
// collect all the assignments from anonymous conditionals
60+
.flatMap {
61+
case net @ DFNet.Assignment(toVal, header: DFConditional.Header) if header.isAnonymous =>
62+
header.patchChainsNet(toVal, net, net.op)
63+
case net @ DFNet.Connection(toVal: DFVal, header: DFConditional.Header, _)
64+
if !net.isViaConnection && header.isAnonymous && (toVal.isPortOut || toVal.isVar) =>
65+
header.patchChainsNet(toVal, net, DFNet.Op.Assignment)
66+
case _ => Nil
67+
}.toList
68+
val patchList2 = headers.map { h =>
69+
h -> Patch.Replace(
70+
h.updateDFType(DFUnit),
71+
Patch.Replace.Config.FullReplacement,
72+
IgnoreTypeRefs
73+
)
74+
}
75+
designDB
76+
.patch(patchList1)
77+
.patch(patchList2)
78+
end transform
79+
end ExplicitCondExprAssign
80+
81+
extension [T: HasDB](t: T)
82+
def explicitCondExprAssign(using CompilerOptions): DB =
83+
StageRunner.run(ExplicitCondExprAssign)(t.db)

compiler/stages/src/main/scala/dfhdl/compiler/stages/NamedAliases.scala

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,15 @@ case object NamedVerilogSelection extends NamedAliases:
9696
if !lhs.hasVerilogName && carryOps.contains(op) && func.width > lhs.width =>
9797
List(lhs)
9898
// anonymous conditional expressions
99-
case ch: DFConditional.Header if ch.isAnonymous && ch.dfType != DFUnit => List(ch)
100-
case _ => Nil
99+
case ch: DFConditional.Header if ch.isAnonymous && ch.dfType != DFUnit =>
100+
ch.getReadDeps.head match
101+
// if the conditional is referred from a net, it is not a selection to be named
102+
case net: DFNet => Nil
103+
// if the conditional is referred from an ident, it is not a selection to be named
104+
case Ident(_) => Nil
105+
// otherwise, it is a selection to be named
106+
case _ => List(ch)
107+
case _ => Nil
101108
end NamedVerilogSelection
102109

103110
extension [T: HasDB](t: T)
@@ -134,6 +141,7 @@ case object NamedAnonMultiref extends NamedAliases, NoCheckStage:
134141
//Names anonymous conditional expressions, as long as they are not referenced by an ident which indicates that
135142
//they are themselves inside another conditional expression
136143
case object NamedAnonCondExpr extends NamedAliases:
144+
override def dependencies: List[Stage] = List(ExplicitCondExprAssign)
137145
def criteria(dfVal: DFVal)(using MemberGetSet): List[DFVal] = dfVal match
138146
case dfVal: DFConditional.Header if dfVal.isAnonymous && dfVal.dfType != DFUnit =>
139147
val isReferencedByIdent =
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
package StagesSpec
2+
3+
import dfhdl.*
4+
import dfhdl.compiler.stages.explicitCondExprAssign
5+
// scalafmt: { align.tokens = [{code = "<>"}, {code = "="}, {code = "=>"}, {code = ":="}]}
6+
7+
class ExplicitCondExprAssignSpec extends StageSpec(stageCreatesUnrefAnons = true):
8+
test("Conditional expression assignment") {
9+
class ID extends DFDesign:
10+
val x = SInt(16) <> IN
11+
val y = SInt(16) <> OUT
12+
val z = SInt(16) <> VAR
13+
z := (
14+
if (x > 0) 5
15+
else if (x < 0) x + 1
16+
else x
17+
)
18+
val z2 = SInt(16) <> VAR
19+
z2 := z match
20+
case 1 | 2 => 17
21+
case _ => z + 12
22+
y := z2
23+
val id = (new ID).explicitCondExprAssign
24+
assertCodeString(
25+
id,
26+
"""|class ID extends DFDesign:
27+
| val x = SInt(16) <> IN
28+
| val y = SInt(16) <> OUT
29+
| val z = SInt(16) <> VAR
30+
| if (x > sd"16'0") z := sd"16'5"
31+
| else if (x < sd"16'0") z := x + sd"16'1"
32+
| else z := x
33+
| val z2 = SInt(16) <> VAR
34+
| z match
35+
| case sd"16'1" | sd"16'2" => z2 := sd"16'17"
36+
| case _ => z2 := z + sd"16'12"
37+
| end match
38+
| y := z2
39+
|end ID
40+
|""".stripMargin
41+
)
42+
}
43+
test("Nested conditional expression assignment") {
44+
class ID extends DFDesign:
45+
val x = SInt(16) <> IN
46+
val y = SInt(16) <> OUT
47+
val z = SInt(16) <> VAR
48+
z := (
49+
if (x > 0)
50+
if (x > 5) 5
51+
else -5
52+
else if (x < 0) x + 1
53+
else x
54+
)
55+
val z2 = SInt(16) <> VAR
56+
z2 := z match
57+
case 1 | 2 =>
58+
val zz: SInt[4] <> VAL = z match
59+
case 1 => 5
60+
case 2 => 3
61+
if (x < 11) zz + 3
62+
else zz
63+
case _ => z + 12
64+
y := z
65+
end ID
66+
val id = (new ID).explicitCondExprAssign
67+
assertCodeString(
68+
id,
69+
"""|class ID extends DFDesign:
70+
| val x = SInt(16) <> IN
71+
| val y = SInt(16) <> OUT
72+
| val z = SInt(16) <> VAR
73+
| if (x > sd"16'0")
74+
| if (x > sd"16'5") z := sd"16'5"
75+
| else z := sd"16'-5"
76+
| else if (x < sd"16'0") z := x + sd"16'1"
77+
| else z := x
78+
| end if
79+
| val z2 = SInt(16) <> VAR
80+
| z match
81+
| case sd"16'1" | sd"16'2" =>
82+
| val zz: SInt[4] <> VAL =
83+
| z match
84+
| case sd"16'1" => sd"4'5"
85+
| case sd"16'2" => sd"4'3"
86+
| end match
87+
| if (x < sd"16'11") z2 := (zz + sd"4'3").resize(16)
88+
| else z2 := zz.resize(16)
89+
| case _ => z2 := z + sd"16'12"
90+
| end match
91+
| y := z
92+
|end ID""".stripMargin
93+
)
94+
}
95+
96+
test("AES xtime example") {
97+
class xtime extends DFDesign:
98+
val lhs = Bits(8) <> IN
99+
val shifted = lhs << 1
100+
val o = Bits(8) <> OUT
101+
o <> ((
102+
if (lhs(7)) shifted ^ h"1b"
103+
else shifted
104+
): Bits[8] <> VAL)
105+
end xtime
106+
val id = (new xtime).explicitCondExprAssign
107+
assertCodeString(
108+
id,
109+
"""|class xtime extends DFDesign:
110+
| val lhs = Bits(8) <> IN
111+
| val shifted = lhs << 1
112+
| val o = Bits(8) <> OUT
113+
| if (lhs(7)) o := shifted ^ h"1b"
114+
| else o := shifted
115+
|end xtime""".stripMargin
116+
)
117+
}
118+
119+
test("LRShiftFlat example") {
120+
enum ShiftDir extends Encode:
121+
case Left, Right
122+
123+
class LRShiftFlat(
124+
val width: Int <> CONST = 8
125+
) extends RTDesign:
126+
val iBits = Bits(width) <> IN
127+
val shift = UInt.until(width) <> IN
128+
val dir = ShiftDir <> IN
129+
val oBits = Bits(width) <> OUT
130+
oBits := dir match
131+
case ShiftDir.Left => iBits << shift
132+
case ShiftDir.Right => iBits >> shift
133+
end LRShiftFlat
134+
val id = (new LRShiftFlat).explicitCondExprAssign
135+
assertCodeString(
136+
id,
137+
"""|enum ShiftDir(val value: UInt[1] <> CONST) extends Encode.Manual(1):
138+
| case Left extends ShiftDir(d"1'0")
139+
| case Right extends ShiftDir(d"1'1")
140+
|
141+
|class LRShiftFlat(val width: Int <> CONST = 8) extends RTDesign:
142+
| val iBits = Bits(width) <> IN
143+
| val shift = UInt(clog2(width)) <> IN
144+
| val dir = ShiftDir <> IN
145+
| val oBits = Bits(width) <> OUT
146+
| dir match
147+
| case ShiftDir.Left => oBits := iBits << shift.toInt
148+
| case ShiftDir.Right => oBits := iBits >> shift.toInt
149+
| end match
150+
|end LRShiftFlat""".stripMargin
151+
)
152+
}
153+
end ExplicitCondExprAssignSpec

compiler/stages/src/test/scala/StagesSpec/ExplicitNamedVarsSpec.scala

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -131,12 +131,9 @@ class ExplicitNamedVarsSpec extends StageSpec:
131131
| val shifted = Bits(8) <> VAR
132132
| shifted := lhs << 1
133133
| val o = Bits(8) <> OUT
134-
| val o_part = Bits(8) <> VAR
135-
| if (lhs(7)) o_part := shifted ^ h"1b"
136-
| else o_part := shifted
137-
| o <> o_part
138-
|end xtime
139-
|""".stripMargin
134+
| if (lhs(7)) o := shifted ^ h"1b"
135+
| else o := shifted
136+
|end xtime""".stripMargin
140137
)
141138
}
142139

compiler/stages/src/test/scala/StagesSpec/NamedSelectionSpec.scala

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ class NamedSelectionSpec extends StageSpec(stageCreatesUnrefAnons = true):
1111
val i = Byte <> IN
1212
val o = Byte <> OUT
1313
val z = Byte <> OUT
14-
o := ((if (c) i else i): Byte <> VAL)
15-
z := ((i match
14+
o := (if (c) i else i)
15+
o := i | ((if (c) i else i): Byte <> VAL)
16+
z := i match
1617
case all(0) => i
1718
case _ => i
18-
): Byte <> VAL)
1919

2020
val id = (new Mux).verilogNamedSelection
2121
assertCodeString(
@@ -25,16 +25,20 @@ class NamedSelectionSpec extends StageSpec(stageCreatesUnrefAnons = true):
2525
| val i = Bits(8) <> IN
2626
| val o = Bits(8) <> OUT
2727
| val z = Bits(8) <> OUT
28+
| o := ((
29+
| if (c) i
30+
| else i
31+
| ): Bits[8] <> VAL)
2832
| val o_part: Bits[8] <> VAL =
2933
| if (c) i
3034
| else i
31-
| o := o_part
32-
| val z_part: Bits[8] <> VAL =
35+
| o := i | o_part
36+
| z := ((
3337
| i match
3438
| case h"00" => i
3539
| case _ => i
3640
| end match
37-
| z := z_part
41+
| ): Bits[8] <> VAL)
3842
|end Mux
3943
|""".stripMargin
4044
)

core/src/main/scala/dfhdl/compiler/patching/Patch.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -439,13 +439,13 @@ extension (db: DB)
439439
println(members.map(m => s"${m.hashString}: $m").mkString("\n"))
440440
println("----------------------------------------------------------------------------")
441441
println("refTable:")
442-
println(refTable.mkString("\n"))
442+
println(refTable.toList.sortBy(_._1.hashString).mkString("\n"))
443443
println("----------------------------------------------------------------------------")
444444
println("patchedMembers:")
445445
println(patchedMembers.map(m => s"${m.hashString}: $m").mkString("\n"))
446446
println("----------------------------------------------------------------------------")
447447
println("patchedRefTable:")
448-
println(patchedRefTable.mkString("\n"))
448+
println(patchedRefTable.toList.sortBy(_._1.hashString).mkString("\n"))
449449
println("----------------------------------------------------------------------------")
450450
}
451451
DB(patchedMembers, patchedRefTable, globalTags, srcFiles)

lib/src/test/resources/ref/AES.CipherSpecNoOpaques/verilog.v2001/hdl/xtime.v

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,14 @@
44

55
module xtime(
66
input wire [7:0] lhs,
7-
output wire [7:0] o
7+
output reg [7:0] o
88
);
99
`include "dfhdl_defs.vh"
1010
wire [7:0] shifted;
11-
reg [7:0] anon;
12-
assign o = anon;
1311
always @(*)
1412
begin
15-
if (lhs[7]) anon = shifted ^ 8'h1b;
16-
else anon = shifted;
13+
if (lhs[7]) o = shifted ^ 8'h1b;
14+
else o = shifted;
1715
end
1816
assign shifted = lhs << 1;
1917
endmodule

lib/src/test/resources/ref/AES.CipherSpecNoOpaques/verilog.v95/hdl/xtime.v

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,12 @@ module xtime(
88
);
99
`include "dfhdl_defs.vh"
1010
input wire [7:0] lhs;
11-
output wire [7:0] o;
11+
output reg [7:0] o;
1212
wire [7:0] shifted;
13-
reg [7:0] anon;
14-
assign o = anon;
1513
always @(shifted or lhs)
1614
begin
17-
if (lhs[7]) anon = shifted ^ 8'h1b;
18-
else anon = shifted;
15+
if (lhs[7]) o = shifted ^ 8'h1b;
16+
else o = shifted;
1917
end
2018
assign shifted = lhs << 1;
2119
endmodule

0 commit comments

Comments
 (0)