Skip to content

Commit 7631832

Browse files
committed
gateware: fix HalfBandInterpolator backpressure issues
1 parent 29bfc3b commit 7631832

File tree

2 files changed: +173 -107 lines changed
Binary file (613 Bytes) not shown.

firmware/fpga/dsp/fir.py

Lines changed: 173 additions & 107 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
from math import ceil, log2
88

99
from amaranth import Module, Signal, Mux, DomainRenamer
10-
from amaranth.lib import wiring, stream, data, memory
10+
from amaranth.lib import wiring, stream, data, memory, fifo
1111
from amaranth.lib.wiring import In, Out
1212
from amaranth.utils import bits_for
1313

@@ -173,9 +173,17 @@ def elaborate(self, platform):
173173
delay = arm1_taps.index(1)
174174

175175
# Arms
176-
m.submodules.fir0 = fir0 = FIRFilter(arm0_taps, shape=self.data_shape, shape_out=self.shape_out, always_ready=always_ready, num_channels=self.num_channels)
177-
m.submodules.fir1 = fir1 = Delay(delay, shape=self.data_shape, always_ready=always_ready, num_channels=self.num_channels)
178-
arms = [fir0, fir1]
176+
m.submodules.fir = fir = FIRFilter(arm0_taps, shape=self.data_shape, shape_out=self.shape_out, always_ready=always_ready, num_channels=self.num_channels)
177+
m.submodules.dly = dly = Delay(delay, shape=self.data_shape, always_ready=always_ready, num_channels=self.num_channels)
178+
m.submodules.dly_fifo = dly_fifo = fifo.SyncFIFOBuffered(width=self.num_channels*self.data_shape.as_shape().width, depth=1)
179+
arms = [fir, dly]
180+
181+
m.d.comb += [
182+
dly_fifo.w_data.eq(dly.output.p),
183+
dly_fifo.w_en.eq(dly.output.valid),
184+
]
185+
if not dly.output.signature.always_ready:
186+
m.d.comb += dly.output.ready.eq(dly_fifo.w_rdy)
179187

180188
with m.FSM():
181189

@@ -198,7 +206,6 @@ def elaborate(self, platform):
198206
m.next = "BYPASS"
199207

200208
# Input
201-
202209
for i, arm in enumerate(arms):
203210
m.d.comb += arm.input.payload.eq(self.input.payload)
204211
m.d.comb += arm.input.valid.eq(self.input.valid & arms[i^1].input.ready)
@@ -211,29 +218,25 @@ def elaborate(self, platform):
211218
arm_index = Signal()
212219

213220
# Output buffers for each arm.
214-
arm_outputs = [arm.output for arm in arms]
215-
if self.output.signature.always_ready:
216-
buffers = [stream.Signature(arm.payload.shape()).create() for arm in arm_outputs]
217-
for arm, buf in zip(arm_outputs, buffers):
218-
with m.If(~buf.valid | buf.ready):
219-
if not arm.signature.always_ready:
220-
m.d.comb += arm.ready.eq(1)
221-
m.d.sync += buf.valid.eq(arm.valid)
222-
with m.If(arm.valid):
223-
m.d.sync += buf.payload.eq(arm.payload)
224-
arm_outputs = buffers
221+
r_data_cast = data.ArrayLayout(self.data_shape, self.num_channels)(dly_fifo.r_data)
225222

226223
with m.If(~self.output.valid | self.output.ready):
227224
with m.Switch(arm_index):
228-
for i, arm in enumerate(arm_outputs):
229-
with m.Case(i):
230-
for c in range(self.num_channels):
231-
m.d.sync += self.output.payload[c].eq(arm.payload[c])
232-
m.d.sync += self.output.valid.eq(arm.valid)
233-
if not arm.signature.always_ready:
234-
m.d.comb += arm.ready.eq(1)
235-
with m.If(arm.valid):
236-
m.d.sync += arm_index.eq(arm_index ^ 1)
225+
with m.Case(0):
226+
for c in range(self.num_channels):
227+
m.d.sync += self.output.payload[c].eq(fir.output.payload[c])
228+
m.d.sync += self.output.valid.eq(fir.output.valid)
229+
if not fir.output.signature.always_ready:
230+
m.d.comb += fir.output.ready.eq(1)
231+
with m.If(fir.output.valid):
232+
m.d.sync += arm_index.eq(1)
233+
with m.Case(1):
234+
for c in range(self.num_channels):
235+
m.d.sync += self.output.payload[c].eq(r_data_cast[c])
236+
m.d.sync += self.output.valid.eq(dly_fifo.r_rdy)
237+
m.d.comb += dly_fifo.r_en.eq(1)
238+
with m.If(dly_fifo.r_rdy):
239+
m.d.sync += arm_index.eq(0)
237240

238241
if self._domain != "sync":
239242
m = DomainRenamer(self._domain)(m)
@@ -439,24 +442,26 @@ def _generate_samples(self, count, width, f_width=0):
439442
return samples / (1 << f_width)
440443
return samples
441444

442-
def _filter(self, dut, samples, count, num_channels=1, outfile=None, empty_cycles=0):
445+
def _filter(self, dut, samples, count, num_channels=1, outfile=None, empty_cycles=0, empty_ready_cycles=0):
443446

444447
async def input_process(ctx):
445448
if hasattr(dut, "enable"):
446449
ctx.set(dut.enable, 1)
447-
await ctx.tick()
448-
ctx.set(dut.input.valid, 1)
449-
for sample in samples:
450+
await ctx.tick()
451+
452+
for i, sample in enumerate(samples):
450453
if num_channels > 1:
451454
ctx.set(dut.input.payload, [s.item() for s in sample])
452455
else:
453-
ctx.set(dut.input.payload, [sample.item()])
456+
if isinstance(dut.input.payload.shape(), data.ArrayLayout):
457+
ctx.set(dut.input.payload, [sample.item()])
458+
else:
459+
ctx.set(dut.input.payload, sample.item())
460+
ctx.set(dut.input.valid, 1)
454461
await ctx.tick().until(dut.input.ready)
462+
ctx.set(dut.input.valid, 0)
455463
if empty_cycles > 0:
456-
ctx.set(dut.input.valid, 0)
457464
await ctx.tick().repeat(empty_cycles)
458-
ctx.set(dut.input.valid, 1)
459-
ctx.set(dut.input.valid, 0)
460465

461466
filtered = []
462467
async def output_process(ctx):
@@ -467,7 +472,14 @@ async def output_process(ctx):
467472
if num_channels > 1:
468473
filtered.append([v.as_float() for v in payload])
469474
else:
470-
filtered.append(payload[0].as_float())
475+
if isinstance(payload.shape(), data.ArrayLayout):
476+
filtered.append(payload[0].as_float())
477+
else:
478+
filtered.append(payload.as_float())
479+
if empty_ready_cycles > 0:
480+
ctx.set(dut.output.ready, 0)
481+
await ctx.tick().repeat(empty_ready_cycles)
482+
ctx.set(dut.output.ready, 1)
471483
if not dut.output.signature.always_ready:
472484
ctx.set(dut.output.ready, 0)
473485

@@ -498,100 +510,154 @@ def test_filter(self):
498510
filtered_np = np.convolve(input_samples, taps).tolist()
499511

500512
# Simulate DUT
501-
dut = FIRFilter(taps, fixed.SQ(15, 0), always_ready=True)
502-
filtered = self._filter(dut, input_samples, len(input_samples))
513+
dut = FIRFilter(taps, shape=fixed.SQ(8, 0), always_ready=False)
514+
filtered = self._filter(dut, input_samples, len(input_samples), empty_ready_cycles=5)
503515

504516
self.assertListEqual(filtered_np[:len(filtered)], filtered)
505517

506518

507519
class TestHalfBandDecimator(_TestFilter):
508520

509-
def test_filter_no_backpressure(self):
510-
taps = [-1, 0, 9, 16, 9, 0, -1]
511-
taps = [ tap / 32 for tap in taps ]
512-
513-
num_samples = 1024
514-
input_width = 8
515-
samples_i_in = self._generate_samples(num_samples, input_width, f_width=7)
516-
samples_q_in = self._generate_samples(num_samples, input_width, f_width=7)
517-
518-
# Compute the expected result
519-
filtered_i_np = np.convolve(samples_i_in, taps)[1::2].tolist()
520-
filtered_q_np = np.convolve(samples_q_in, taps)[1::2].tolist()
521-
522-
# Simulate DUT
523-
dut = HalfBandDecimator(taps, data_shape=fixed.SQ(7), shape_out=fixed.SQ(0,16), always_ready=True)
524-
filtered = self._filter(dut, zip(samples_i_in, samples_q_in), len(samples_i_in) // 2, num_channels=2)
525-
filtered_i = [ x[0] for x in filtered ]
526-
filtered_q = [ x[1] for x in filtered ]
527-
528-
self.assertListEqual(filtered_i_np[:len(filtered_i)], filtered_i)
529-
self.assertListEqual(filtered_q_np[:len(filtered_q)], filtered_q)
530-
531-
def test_filter_with_spare_cycles(self):
532-
taps = [-1, 0, 9, 16, 9, 0, -1]
533-
taps = [ tap / 32 for tap in taps ]
534-
535-
num_samples = 1024
536-
input_width = 8
537-
samples_i_in = self._generate_samples(num_samples, input_width, f_width=7)
538-
samples_q_in = self._generate_samples(num_samples, input_width, f_width=7)
521+
def test_filter(self):
539522

540-
# Compute the expected result
541-
filtered_i_np = np.convolve(samples_i_in, taps)[1::2].tolist()
542-
filtered_q_np = np.convolve(samples_q_in, taps)[1::2].tolist()
523+
common_dut_options = dict(
524+
data_shape=fixed.SQ(7),
525+
shape_out=fixed.SQ(0,31),
526+
)
543527

544-
# Simulate DUT
545-
dut = HalfBandDecimator(taps, data_shape=fixed.SQ(7), shape_out=fixed.SQ(0,16), always_ready=True)
546-
filtered = self._filter(dut, zip(samples_i_in, samples_q_in), len(samples_i_in) // 2, num_channels=2, empty_cycles=3)
547-
filtered_i = [ x[0] for x in filtered ]
548-
filtered_q = [ x[1] for x in filtered ]
528+
taps0 = (np.array([-1, 0, 9, 16, 9, 0, -1]) / 32).tolist()
529+
taps1 = (np.array([-2, 0, 7, 0, -18, 0, 41, 0, -92, 0, 320, 512, 320, 0, -92, 0, 41, 0, -18, 0, 7, 0, -2]) / 1024).tolist()
530+
531+
532+
inputs = {
533+
534+
"test_filter_with_backpressure": {
535+
"num_samples": 1024,
536+
"dut_options": dict(**common_dut_options, always_ready=False, taps=taps0),
537+
"sim_opts": dict(empty_cycles=0),
538+
},
539+
540+
"test_filter_with_backpressure_and_empty_cycles": {
541+
"num_samples": 1024,
542+
"dut_options": dict(**common_dut_options, always_ready=False, taps=taps0),
543+
"sim_opts": dict(empty_cycles=3),
544+
},
545+
546+
"test_filter_with_backpressure_taps1": {
547+
"num_samples": 1024,
548+
"dut_options": dict(**common_dut_options, always_ready=False, taps=taps1),
549+
"sim_opts": dict(empty_cycles=0),
550+
},
551+
552+
"test_filter_no_backpressure_and_empty_cycles_taps1": {
553+
"num_samples": 1024,
554+
"dut_options": dict(**common_dut_options, always_ready=True, taps=taps0),
555+
"sim_opts": dict(empty_cycles=6),
556+
},
557+
558+
"test_filter_no_backpressure": {
559+
"num_samples": 1024,
560+
"dut_options": dict(**common_dut_options, always_ready=True, taps=taps1),
561+
"sim_opts": dict(empty_cycles=3),
562+
},
563+
}
564+
565+
for name, scenario in inputs.items():
549566

550-
self.assertListEqual(filtered_i_np[:len(filtered_i)], filtered_i)
551-
self.assertListEqual(filtered_q_np[:len(filtered_q)], filtered_q)
567+
with self.subTest(name):
568+
taps = scenario["dut_options"]["taps"]
569+
num_samples = scenario["num_samples"]
552570

553-
def test_filter_with_backpressure(self):
554-
taps = [-1, 0, 9, 16, 9, 0, -1]
555-
taps = [ tap / 32 for tap in taps ]
571+
input_width = 8
572+
samples_i_in = self._generate_samples(num_samples, input_width, f_width=7)
573+
samples_q_in = self._generate_samples(num_samples, input_width, f_width=7)
556574

557-
num_samples = 1024
558-
input_width = 8
559-
samples_i_in = self._generate_samples(num_samples, input_width, f_width=7)
560-
samples_q_in = self._generate_samples(num_samples, input_width, f_width=7)
575+
# Compute the expected result
576+
filtered_i_np = np.convolve(samples_i_in, taps)[1::2].tolist()
577+
filtered_q_np = np.convolve(samples_q_in, taps)[1::2].tolist()
561578

562-
# Compute the expected result
563-
filtered_i_np = np.convolve(samples_i_in, taps)[1::2].tolist()
564-
filtered_q_np = np.convolve(samples_q_in, taps)[1::2].tolist()
579+
# Simulate DUT
580+
dut = HalfBandDecimator(**scenario["dut_options"])
581+
filtered = self._filter(dut, zip(samples_i_in, samples_q_in), len(samples_i_in) // 2, num_channels=2, **scenario["sim_opts"])
582+
filtered_i = [ x[0] for x in filtered ]
583+
filtered_q = [ x[1] for x in filtered ]
565584

566-
# Simulate DUT
567-
dut = HalfBandDecimator(taps, data_shape=fixed.SQ(7), shape_out=fixed.SQ(0,16), always_ready=False)
568-
filtered = self._filter(dut, zip(samples_i_in, samples_q_in), len(samples_i_in) // 2, num_channels=2)
569-
filtered_i = [ x[0] for x in filtered ]
570-
filtered_q = [ x[1] for x in filtered ]
585+
self.assertListEqual(filtered_i_np[:len(filtered_i)], filtered_i)
586+
self.assertListEqual(filtered_q_np[:len(filtered_q)], filtered_q)
571587

572-
self.assertListEqual(filtered_i_np[:len(filtered_i)], filtered_i)
573-
self.assertListEqual(filtered_q_np[:len(filtered_q)], filtered_q)
574588

575589
class TestHalfBandInterpolator(_TestFilter):
576590

577591
def test_filter(self):
578-
taps = [-1, 0, 9, 16, 9, 0, -1]
579-
taps = [ tap / 32 for tap in taps ]
580-
num_samples = 1024
581-
input_width = 8
582-
input_samples = self._generate_samples(num_samples, input_width, f_width=7)
583592

584-
# Compute the expected result
585-
input_samples_pad = np.zeros(2*len(input_samples))
586-
input_samples_pad[0::2] = 2*input_samples # pad with zeros, adjust gain
587-
filtered_np = np.convolve(input_samples_pad, taps).tolist()
593+
common_dut_options = dict(
594+
data_shape=fixed.SQ(7),
595+
shape_out=fixed.SQ(1,16),
596+
)
588597

589-
# Simulate DUT
590-
dut = HalfBandInterpolator(taps, data_shape=fixed.SQ(0, 7), shape_out=fixed.SQ(0,16), always_ready=False)
591-
filtered = self._filter(dut, input_samples, len(input_samples) * 2)
598+
taps0 = (np.array([-1, 0, 9, 16, 9, 0, -1]) / 32).tolist()
599+
taps1 = (np.array([-2, 0, 7, 0, -18, 0, 41, 0, -92, 0, 320, 512, 320, 0, -92, 0, 41, 0, -18, 0, 7, 0, -2]) / 1024).tolist()
600+
601+
inputs = {
602+
603+
"test_filter_with_backpressure": {
604+
"num_samples": 1024,
605+
"dut_options": dict(**common_dut_options, always_ready=False, num_channels=2, taps=taps1),
606+
"sim_opts": dict(empty_cycles=0, empty_ready_cycles=0),
607+
},
608+
609+
"test_filter_with_backpressure_and_empty_cycles": {
610+
"num_samples": 1024,
611+
"dut_options": dict(**common_dut_options, num_channels=2, always_ready=False, taps=taps0),
612+
"sim_opts": dict(empty_ready_cycles=7, empty_cycles=3),
613+
},
614+
615+
"test_filter_with_backpressure_taps1": {
616+
"num_samples": 1024,
617+
"dut_options": dict(**common_dut_options, num_channels=2, always_ready=False, taps=taps1),
618+
"sim_opts": dict(empty_ready_cycles=7, empty_cycles=0),
619+
},
620+
621+
"test_filter_no_backpressure_and_empty_cycles_taps1": {
622+
"num_samples": 1024,
623+
"dut_options": dict(**common_dut_options, num_channels=2, always_ready=True, taps=taps0),
624+
"sim_opts": dict(empty_cycles=8),
625+
},
626+
627+
"test_filter_no_backpressure": {
628+
"num_samples": 1024,
629+
"dut_options": dict(**common_dut_options, num_channels=2, always_ready=True, taps=taps1),
630+
"sim_opts": dict(empty_cycles=16),
631+
},
592632

593-
self.assertListEqual(filtered_np[:len(filtered)], filtered)
633+
}
594634

635+
636+
for name, scenario in inputs.items():
637+
with self.subTest(name):
638+
taps = scenario["dut_options"]["taps"]
639+
num_samples = scenario["num_samples"]
640+
641+
input_width = 8
642+
samples_i_in = self._generate_samples(num_samples, input_width, f_width=7)
643+
samples_q_in = self._generate_samples(num_samples, input_width, f_width=7)
644+
645+
# Compute the expected result
646+
input_samples_pad = np.zeros(2*len(samples_i_in))
647+
input_samples_pad[0::2] = 2*samples_i_in # pad with zeros, adjust gain
648+
filtered_i_np = np.convolve(input_samples_pad, taps).tolist()
649+
input_samples_pad = np.zeros(2*len(samples_q_in))
650+
input_samples_pad[0::2] = 2*samples_q_in # pad with zeros, adjust gain
651+
filtered_q_np = np.convolve(input_samples_pad, taps).tolist()
652+
653+
# Simulate DUT
654+
dut = HalfBandInterpolator(**scenario["dut_options"])
655+
filtered = self._filter(dut, zip(samples_i_in, samples_q_in), len(samples_i_in) * 2, num_channels=2, **scenario["sim_opts"])
656+
filtered_i = [ x[0] for x in filtered ]
657+
filtered_q = [ x[1] for x in filtered ]
658+
659+
self.assertListEqual(filtered_i_np[:len(filtered_i)], filtered_i)
660+
self.assertListEqual(filtered_q_np[:len(filtered_q)], filtered_q)
595661

596662
if __name__ == "__main__":
597-
unittest.main()
663+
unittest.main()

0 commit comments

Comments
 (0)