Skip to content

Commit 6682333

Browse files
committed
ReduceArg ops returns the arg index and its corresponding value.
1 parent e3849f6 commit 6682333

File tree

2 files changed

+32
-22
lines changed

2 files changed

+32
-22
lines changed

tests/extension/thread_/stream_reduce_arg_max/thread_stream_reduce_arg_max.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,24 +26,24 @@ def mkLed():
2626
strm = vthread.Stream(m, 'mystream', clk, rst)
2727
a = strm.source('a')
2828
size = strm.constant('size')
29-
max, max_valid = strm.ReduceArgMaxValid(a, size)
30-
strm.sink(max, 'max', when=max_valid, when_name='max_valid')
29+
index, _max, argmax_valid = strm.ReduceArgMaxValid(a, size)
30+
strm.sink(index, 'index', when=argmax_valid, when_name='argmax_valid')
3131

3232
def comp_stream(size, offset):
3333
strm.set_source('a', ram_a, offset, size)
3434
strm.set_constant('size', size)
35-
strm.set_sink('max', ram_b, offset, 1)
35+
strm.set_sink('index', ram_b, offset, 1)
3636
strm.run()
3737
strm.join()
3838

3939
def comp_sequential(size, offset):
4040
index = 0
41-
max = 0
41+
_max = 0
4242
for i in range(size):
4343
a = ram_a.read(i + offset)
44-
if a > max:
44+
if a > _max:
4545
index = i
46-
max = a
46+
_max = a
4747
ram_b.write(offset, index)
4848

4949
def check(size, offset_stream, offset_seq):

veriloggen/stream/stypes.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2610,6 +2610,7 @@ def __init__(self, right, size=None, initval=0,
26102610
enable=None, reset=None, width=32, signed=True):
26112611
_Accumulator.__init__(self, right, size, initval,
26122612
enable, reset, width, signed)
2613+
self.graph_label = 'ReduceAdd'
26132614

26142615

26152616
class ReduceSub(_Accumulator):
@@ -2619,6 +2620,7 @@ def __init__(self, right, size=None, initval=0,
26192620
enable=None, reset=None, width=32, signed=True):
26202621
_Accumulator.__init__(self, right, size, initval,
26212622
enable, reset, width, signed)
2623+
self.graph_label = 'ReduceSub'
26222624

26232625

26242626
class ReduceMul(_Accumulator):
@@ -2629,6 +2631,7 @@ def __init__(self, right, size=None, initval=0,
26292631
enable=None, reset=None, width=32, signed=True):
26302632
_Accumulator.__init__(self, right, size, initval,
26312633
enable, reset, width, signed)
2634+
self.graph_label = 'ReduceMul'
26322635

26332636

26342637
class ReduceDiv(_Accumulator):
@@ -2640,6 +2643,7 @@ def __init__(self, right, size=None, initval=0,
26402643
raise NotImplementedError()
26412644
_Accumulator.__init__(self, right, size, initval,
26422645
enable, reset, width, signed)
2646+
self.graph_label = 'ReduceDiv'
26432647

26442648

26452649
class ReduceMax(_Accumulator):
@@ -2649,6 +2653,7 @@ def __init__(self, right, size=None, initval=0,
26492653
enable=None, reset=None, width=32, signed=True):
26502654
_Accumulator.__init__(self, right, size, initval,
26512655
enable, reset, width, signed)
2656+
self.graph_label = 'ReduceMax'
26522657

26532658

26542659
class ReduceMin(_Accumulator):
@@ -2658,6 +2663,7 @@ def __init__(self, right, size=None, initval=0,
26582663
enable=None, reset=None, width=32, signed=True):
26592664
_Accumulator.__init__(self, right, size, initval,
26602665
enable, reset, width, signed)
2666+
self.graph_label = 'ReduceMin'
26612667

26622668

26632669
class ReduceCustom(_Accumulator):
@@ -3444,45 +3450,49 @@ def enable(self):
34443450
def ReduceArgMax(right, size=None, initval=0,
34453451
enable=None, reset=None, width=32, signed=True):
34463452

3447-
reduce_max = ReduceMax(right, size, initval,
3448-
enable, reset, width, signed)
3453+
_max = ReduceMax(right, size, initval,
3454+
enable, reset, width, signed)
34493455
counter = Counter(size, control=right, enable=enable, reset=reset)
3450-
update = NotEq(reduce_max, reduce_max.prev(1))
3456+
update = NotEq(_max, _max.prev(1))
34513457
update.latency = 0
3452-
return Predicate(counter, update)
3458+
index = Predicate(counter, update)
3459+
return index, _max
34533460

34543461

34553462
def ReduceArgMin(right, size=None, initval=0,
34563463
enable=None, reset=None, width=32, signed=True):
34573464

3458-
reduce_min = ReduceMin(right, size, initval,
3459-
enable, reset, width, signed)
3465+
_min = ReduceMin(right, size, initval,
3466+
enable, reset, width, signed)
34603467
counter = Counter(size, control=right, enable=enable, reset=reset)
3461-
update = NotEq(reduce_min, reduce_min.prev(1))
3468+
update = NotEq(_min, reduce_min.prev(1))
34623469
update.latency = 0
3463-
return Predicate(counter, update)
3470+
index = Predicate(counter, update)
3471+
return index, _min
34643472

34653473

34663474
def ReduceArgMaxValid(right, size=None, initval=0,
34673475
enable=None, reset=None, width=32, signed=True):
34683476

3469-
reduce_max, valid = ReduceMaxValid(right, size, initval,
3470-
enable, reset, width, signed)
3477+
_max, valid = ReduceMaxValid(right, size, initval,
3478+
enable, reset, width, signed)
34713479
counter = Counter(size, control=right, enable=enable, reset=reset)
3472-
update = NotEq(reduce_max, reduce_max.prev(1))
3480+
update = NotEq(_max, _max.prev(1))
34733481
update.latency = 0
3474-
return Predicate(counter, update), valid
3482+
index = Predicate(counter, update)
3483+
return index, _max, valid
34753484

34763485

34773486
def ReduceArgMinValid(right, size=None, initval=0,
34783487
enable=None, reset=None, width=32, signed=True):
34793488

3480-
reduce_min, valid = ReduceMinValid(right, size, initval,
3481-
enable, reset, width, signed)
3489+
_min, valid = ReduceMinValid(right, size, initval,
3490+
enable, reset, width, signed)
34823491
counter = Counter(size, control=right, enable=enable, reset=reset)
3483-
update = NotEq(reduce_min, reduce_min.prev(1))
3492+
update = NotEq(_min, _min.prev(1))
34843493
update.latency = 0
3485-
return Predicate(counter, update), valid
3494+
index = Predicate(counter, update)
3495+
return index, _min, valid
34863496

34873497

34883498
def make_condition(*cond, **kwargs):

0 commit comments

Comments
 (0)