Skip to content

Commit 3560b21

Browse files
committed
reduce latency for sra_round
1 parent 634ce50 commit 3560b21

File tree

3 files changed

+59
-94
lines changed

3 files changed

+59
-94
lines changed

tests/extension/stream_/sra_round/stream_sra_round.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import veriloggen.stream as stream
1212

1313
from decimal import Decimal, ROUND_HALF_UP, ROUND_HALF_EVEN
14-
#from pprint import pprint
14+
from pprint import pprint
1515

1616
def mkMain():
1717
# input variiable
@@ -27,6 +27,8 @@ def mkMain():
2727
st = stream.Stream(z)
2828
m = st.to_module('main')
2929

30+
#st.draw_graph()
31+
3032
return m, st.pipeline_depth()
3133

3234

@@ -146,14 +148,14 @@ def mkTest(numports=8):
146148
sim = simulation.Simulator(test)
147149
rslt = sim.run() # display=False
148150
#rslt = sim.run(display=True)
149-
print(rslt)
151+
#print(rslt)
150152

151153
vx = list(map(lambda x: int(str.split(x,"=")[1]), filter(lambda x: "xdata" in x , str.split(rslt, "\n"))))
152154
vy = list(map(lambda x: int(str.split(x,"=")[1]), filter(lambda x: "ydata" in x , str.split(rslt, "\n"))))
153155
vz = list(map(lambda x: int(str.split(x,"=")[1]), filter(lambda x: "zdata" in x , str.split(rslt, "\n"))))
154156
ez = list(map(lambda x,y: int( Decimal(str(x/(2.0**y))).quantize(Decimal('0'), rounding=ROUND_HALF_UP)), vx,vy))
155157

156-
#pprint(list(zip(lx,ly,lz,ez)))
158+
pprint(list(zip(vx,vy,vz,ez)))
157159
assert(all(map(lambda v, e: v==e, vz, ez)))
158160

159161
# launch waveform viewer (GTKwave)

tests/extension/stream_/sra_round/test_stream_sra_round.py

Lines changed: 32 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -868,12 +868,12 @@
868868
recv_fsm <= recv_fsm_2;
869869
end
870870
recv_fsm_2: begin
871-
if(recv_count < 5) begin
871+
if(recv_count < 0) begin
872872
recv_count <= recv_count + 1;
873873
end else begin
874874
recv_count <= 0;
875875
end
876-
if(recv_count >= 5) begin
876+
if(recv_count >= 0) begin
877877
recv_fsm <= recv_fsm_3;
878878
end
879879
end
@@ -1338,100 +1338,45 @@
13381338
output signed [32-1:0] zdata
13391339
);
13401340
1341-
reg [1-1:0] _pointer_data_2;
1342-
reg [5-1:0] _slice_data_8;
1343-
reg [1-1:0] _eq_data_24;
1341+
wire [1-1:0] _pointer_data_2;
1342+
assign _pointer_data_2 = xdata[6'sd31];
1343+
wire [5-1:0] _slice_data_6;
1344+
assign _slice_data_6 = ydata[4'd4:1'd0];
1345+
wire [5-1:0] _minus_data_7;
1346+
assign _minus_data_7 = _slice_data_6 - 2'sd1;
1347+
wire signed [34-1:0] _sll_data_10;
1348+
assign _sll_data_10 = 2'sd1 << _minus_data_7;
1349+
wire signed [2-1:0] _cond_data_18;
1350+
assign _cond_data_18 = (_pointer_data_2)? -2'sd1 : 1'sd0;
1351+
wire signed [33-1:0] _plus_data_19;
1352+
assign _plus_data_19 = xdata + _sll_data_10;
1353+
wire signed [33-1:0] _plus_data_20;
1354+
assign _plus_data_20 = _plus_data_19 + _cond_data_18;
1355+
wire signed [32-1:0] _sra_data_21;
1356+
assign _sra_data_21 = _plus_data_20 >>> ydata;
1357+
reg [1-1:0] _eq_data_23;
1358+
reg signed [32-1:0] __delay_data_25;
13441359
reg signed [32-1:0] __delay_data_26;
1345-
reg signed [32-1:0] __delay_data_31;
1346-
reg [5-1:0] _minus_data_9;
1347-
reg signed [2-1:0] _cond_data_19;
1348-
reg signed [32-1:0] __delay_data_27;
1349-
reg signed [32-1:0] __delay_data_32;
1350-
reg [1-1:0] __delay_data_36;
1351-
reg signed [34-1:0] _sll_data_11;
1352-
reg signed [32-1:0] __delay_data_28;
1353-
reg signed [2-1:0] __delay_data_29;
1354-
reg signed [32-1:0] __delay_data_33;
1355-
reg [1-1:0] __delay_data_37;
1356-
reg signed [34-1:0] _plus_data_20;
1357-
reg signed [2-1:0] __delay_data_30;
1358-
reg signed [32-1:0] __delay_data_34;
1359-
reg [1-1:0] __delay_data_38;
1360-
reg signed [32-1:0] __delay_data_41;
1361-
reg signed [34-1:0] _plus_data_21;
1362-
reg signed [32-1:0] __delay_data_35;
1363-
reg [1-1:0] __delay_data_39;
1364-
reg signed [32-1:0] __delay_data_42;
1365-
reg signed [32-1:0] _sra_data_22;
1366-
reg [1-1:0] __delay_data_40;
1367-
reg signed [32-1:0] __delay_data_43;
1368-
reg signed [32-1:0] _cond_data_25;
1369-
assign zdata = _cond_data_25;
1360+
reg signed [32-1:0] _cond_data_24;
1361+
assign zdata = _cond_data_24;
13701362
13711363
always @(posedge CLK) begin
13721364
if(RST) begin
1373-
_pointer_data_2 <= 0;
1374-
_slice_data_8 <= 0;
1375-
_eq_data_24 <= 0;
1365+
_eq_data_23 <= 0;
1366+
__delay_data_25 <= 0;
13761367
__delay_data_26 <= 0;
1377-
__delay_data_31 <= 0;
1378-
_minus_data_9 <= 0;
1379-
_cond_data_19 <= 0;
1380-
__delay_data_27 <= 0;
1381-
__delay_data_32 <= 0;
1382-
__delay_data_36 <= 0;
1383-
_sll_data_11 <= 0;
1384-
__delay_data_28 <= 0;
1385-
__delay_data_29 <= 0;
1386-
__delay_data_33 <= 0;
1387-
__delay_data_37 <= 0;
1388-
_plus_data_20 <= 0;
1389-
__delay_data_30 <= 0;
1390-
__delay_data_34 <= 0;
1391-
__delay_data_38 <= 0;
1392-
__delay_data_41 <= 0;
1393-
_plus_data_21 <= 0;
1394-
__delay_data_35 <= 0;
1395-
__delay_data_39 <= 0;
1396-
__delay_data_42 <= 0;
1397-
_sra_data_22 <= 0;
1398-
__delay_data_40 <= 0;
1399-
__delay_data_43 <= 0;
1400-
_cond_data_25 <= 0;
1368+
_cond_data_24 <= 0;
14011369
end else begin
1402-
_pointer_data_2 <= xdata[6'sd31];
1403-
_slice_data_8 <= ydata[4'd4:1'd0];
1404-
_eq_data_24 <= ydata == 1'sd0;
1405-
__delay_data_26 <= xdata;
1406-
__delay_data_31 <= ydata;
1407-
_minus_data_9 <= _slice_data_8 - 2'sd1;
1408-
_cond_data_19 <= (_pointer_data_2)? -2'sd1 : 1'sd0;
1409-
__delay_data_27 <= __delay_data_26;
1410-
__delay_data_32 <= __delay_data_31;
1411-
__delay_data_36 <= _eq_data_24;
1412-
_sll_data_11 <= 2'sd1 << _minus_data_9;
1413-
__delay_data_28 <= __delay_data_27;
1414-
__delay_data_29 <= _cond_data_19;
1415-
__delay_data_33 <= __delay_data_32;
1416-
__delay_data_37 <= __delay_data_36;
1417-
_plus_data_20 <= __delay_data_28 + _sll_data_11;
1418-
__delay_data_30 <= __delay_data_29;
1419-
__delay_data_34 <= __delay_data_33;
1420-
__delay_data_38 <= __delay_data_37;
1421-
__delay_data_41 <= __delay_data_28;
1422-
_plus_data_21 <= _plus_data_20 + __delay_data_30;
1423-
__delay_data_35 <= __delay_data_34;
1424-
__delay_data_39 <= __delay_data_38;
1425-
__delay_data_42 <= __delay_data_41;
1426-
_sra_data_22 <= _plus_data_21 >>> __delay_data_35;
1427-
__delay_data_40 <= __delay_data_39;
1428-
__delay_data_43 <= __delay_data_42;
1429-
_cond_data_25 <= (__delay_data_40)? __delay_data_43 : _sra_data_22;
1370+
_eq_data_23 <= ydata == 1'sd0;
1371+
__delay_data_25 <= xdata;
1372+
__delay_data_26 <= _sra_data_21;
1373+
_cond_data_24 <= (_eq_data_23)? __delay_data_25 : __delay_data_26;
14301374
end
14311375
end
14321376
14331377
14341378
endmodule
1379+
14351380
"""
14361381

14371382

@@ -1460,7 +1405,7 @@ def test():
14601405
vz = list(map(lambda x: int(str.split(x,"=")[1]), filter(lambda x: "zdata" in x , str.split(rslt, "\n"))))
14611406
ez = list(map(lambda x,y: int( Decimal(str(x/(2.0**y))).quantize(Decimal('0'), rounding=ROUND_HALF_UP)), vx,vy))
14621407

1463-
#pprint(list(zip(lx,ly,lz,ez)))
1408+
#pprint(list(zip(vx,vy,vz,ez)))
14641409

1465-
assert(all(map(lambda x ,y: x == y, vz,ez)))
1410+
assert(all(map(lambda x ,y: x == y, vz, ez)))
14661411

veriloggen/stream/stypes.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2169,15 +2169,33 @@ def AverageRound(*args):
21692169

21702170
def SraRound(left, right):
21712171
msb = left[-1]
2172+
msb.latency = 0
2173+
2174+
if isinstance(right, int):
2175+
rounder = Sll(Int(1), right - 1)
2176+
else:
2177+
right_slice = right[0:int(log(right.width, 2))]
2178+
right_slice.latency = 0
2179+
right_slice = right_slice - 1
2180+
right_slice.latency = 0
2181+
rounder = Sll(Int(1), right_slice)
2182+
2183+
rounder.latency = 0
2184+
rounder_sign = Mux(msb, Int(-1), Int(0))
2185+
rounder_sign.latency = 0
2186+
2187+
# if left.width < right
2188+
# raise ValueError("Shift amount of SraRound operator must be less than val bit width")
21722189

2173-
pre_round = Int(0)
2190+
pre_round = left + rounder
21742191
pre_round.width = left.width + 1
2192+
pre_round.latency = 0
2193+
pre_round = pre_round + rounder_sign
2194+
pre_round.latency = 0
21752195

2176-
rounder = Sll(Int(1), right[0:int(log(right.width, 2))] - 1)
2177-
rounder_sign = Mux(msb, Int(-1), Int(0))
2178-
pre_round = left + rounder + rounder_sign
21792196
shifted = Sra(pre_round, right)
21802197
shifted.width = left.width
2198+
shifted.latency = 0
21812199

21822200
return Mux(right == Int(0), left, shifted)
21832201

0 commit comments

Comments
 (0)