Skip to content

Commit 5fdd569

Browse files
committed
Fixed case example
1 parent 03cb834 commit 5fdd569

File tree

1 file changed

+24
-19
lines changed

1 file changed

+24
-19
lines changed

pycde_example/case_example.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import pycde
22

3-
from pycde import Clock, Module, Reset, Input, Output, generator, ir
3+
from pycde import Clock, Module, Reset, Input, Output, generator, ir
44
from pycde.types import Bits
5-
from pycde.circt.dialects import sv
6-
from pycde.circt.ir import IntegerType, IntegerAttr, InsertionPoint
5+
from pycde.circt.dialects import sv, hw
76
from pycde import support
7+
from pycde.signals import _FromCirctValue
88

99
def unknown_location():
1010
return ir.Location.unknown()
@@ -24,30 +24,35 @@ class CaseExample(Module):
2424

2525
@generator
2626
def construct(ports):
27+
i32_type = ir.IntegerType.get_signless(32)
28+
result_reg = sv.RegOp(hw.InOutType.get(i32_type), name="result_reg")
29+
2730
al = sv.AlwaysCombOp()
2831
al.body.blocks.append()
29-
with InsertionPoint(al.body.blocks[0]):
30-
i6 = IntegerType.get_signless(6)
32+
with ir.InsertionPoint(al.body.blocks[0]):
33+
case_conditions = [sum( ((n >> i) & 1) << (2*i) for i in range(n.bit_length()) ) for n in range(80)]
34+
case_values = list(range(81))[::-1]
35+
36+
# 创建case patterns
37+
i6 = ir.IntegerType.get_signless(max(case_conditions).bit_length() + 1)
38+
case_patterns = [ir.IntegerAttr.get(i6, cond) for cond in case_conditions]
39+
case_patterns.append(ir.UnitAttr.get()) # default 分支
40+
3141
case_op = sv.CaseOp(
3242
cond=ports.data_i.value,
33-
casePatterns=[
34-
IntegerAttr.get(i6, 0), # case 0
35-
IntegerAttr.get(i6, 1), # case 1
36-
IntegerAttr.get(i6, 4), # case 2
37-
IntegerAttr.get(i6, 5), # case 3
38-
IntegerAttr.get(i6, 16), # case 4
39-
IntegerAttr.get(i6, 17), # case 5
40-
IntegerAttr.get(i6, 20), # case 6
41-
IntegerAttr.get(i6, 21), # case 7
42-
ir.UnitAttr.get(), # default 分支
43-
],
44-
num_caseRegions=9,
43+
casePatterns=case_patterns,
44+
num_caseRegions=len(case_patterns),
4545
)
46+
47+
# 为每个case分支赋值
4648
for i in range(len(case_op.caseRegions)):
4749
case_op.caseRegions[i].blocks.append()
48-
with InsertionPoint(case_op.caseRegions[i].blocks[0]):
50+
with ir.InsertionPoint(case_op.caseRegions[i].blocks[0]):
4951
sv.VerbatimOp(ir.StringAttr.get(f"// value = 32'h{i};\n"), [])
50-
ports.data_o = Bits(32)(0)
52+
sv.BPAssignOp(result_reg, Bits(32)(case_values[i]).value)
53+
54+
# 从寄存器读取值并赋给输出端口
55+
ports.data_o = _FromCirctValue(sv.ReadInOutOp(result_reg).result)
5156

5257
if __name__ == "__main__":
5358

0 commit comments

Comments
 (0)