Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions amaranth_orchard/base/gpio.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
from amaranth import *
from amaranth import Module, unsigned
from amaranth.lib import wiring
from amaranth.lib.wiring import In, Out, flipped, connect

from amaranth_soc import csr

from chipflow_lib.platforms import BidirPinSignature

__all__ = ["GPIOPins", "GPIOPeripheral"]


class GPIOPins(wiring.PureInterface):
class Signature(wiring.Signature):
def __init__(self, width):
if width > 32:
raise ValueError(f"Pin width must be lesser than or equal to 32, not {width}")
self._width = width
super().__init__({
"o": Out(unsigned(width)),
"oe": Out(unsigned(width)),
"i": In(unsigned(width)),
"gpio": Out(BidirPinSignature(width, all_have_oe=True))
})

@property
Expand All @@ -28,6 +29,10 @@ def create(self, *, path=(), src_loc_at=0):
def __init__(self, width, *, path=(), src_loc_at=0):
super().__init__(self.Signature(width), path=path, src_loc_at=1 + src_loc_at)

@property
def width(self):
return self.signature.width


class GPIOPeripheral(wiring.Component):
class DO(csr.Register, access="rw"):
Expand All @@ -45,16 +50,13 @@ class DI(csr.Register, access="r"):
def __init__(self, width):
super().__init__({"pins": csr.Field(csr.action.R, unsigned(width))})

"""Simple GPIO peripheral.

All pins default to input at power up.
"""
def __init__(self, *, pins):
if len(pins.o) > 32:
raise ValueError(f"Pin width must be lesser than or equal to 32, not {len(pins.o)}")
def __init__(self, *, pins: GPIOPins):
"""Simple GPIO peripheral.

self.width = len(pins.o)
self.pins = pins
All pins default to input at power up.
"""
self.width = pins.width
self.pins = pins

regs = csr.Builder(addr_width=4, data_width=8)

Expand All @@ -75,10 +77,8 @@ def elaborate(self, platform):

connect(m, flipped(self.bus), self._bridge.bus)

m.d.comb += [
self.pins.o .eq(self._do.f.pins.data),
self.pins.oe.eq(self._oe.f.pins.data),
]
m.d.sync += self._di.f.pins.r_data.eq(self.pins.i)
m.d.comb += self.pins.gpio.o.eq(self._do.f.pins.data)
m.d.comb += self.pins.gpio.oe.eq(self._oe.f.pins.data)
m.d.comb += self._di.f.pins.r_data.eq(self.pins.gpio.i)

return m
12 changes: 7 additions & 5 deletions amaranth_orchard/io/uart.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from amaranth import *
from amaranth import Module, Signal, unsigned
from amaranth.lib import wiring
from amaranth.lib.wiring import In, Out, flipped, connect

from amaranth_soc import csr
from amaranth_stdio.serial import AsyncSerialRX, AsyncSerialTX

from chipflow_lib.platforms import OutputPinSignature, InputPinSignature


__all__ = ["UARTPins", "UARTPeripheral"]

Expand All @@ -13,8 +15,8 @@ class UARTPins(wiring.PureInterface):
class Signature(wiring.Signature):
def __init__(self):
super().__init__({
"tx_o": Out(1),
"rx_i": In(1),
"tx": Out(OutputPinSignature(1)),
"rx": Out(InputPinSignature(1)),
})

def create(self, *, path=(), src_loc_at=0):
Expand Down Expand Up @@ -80,7 +82,7 @@ def elaborate(self, platform):

m.submodules.tx = tx = AsyncSerialTX(divisor=self.init_divisor, divisor_bits=24)
m.d.comb += [
self.pins.tx_o.eq(tx.o),
self.pins.tx.o.eq(tx.o),
tx.data.eq(self._tx_data.f.val.w_data),
tx.ack.eq(self._tx_data.f.val.w_stb),
self._tx_rdy.f.val.r_data.eq(tx.rdy),
Expand All @@ -102,7 +104,7 @@ def elaborate(self, platform):
]

m.d.comb += [
rx.i.eq(self.pins.rx_i),
rx.i.eq(self.pins.rx.i),
rx.ack.eq(~rx_avail),
rx.divisor.eq(self._divisor.f.val.data),
self._rx_data.f.val.r_data.eq(rx_buf),
Expand Down
3 changes: 3 additions & 0 deletions amaranth_orchard/memory/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .hyperram import * # noqa
from .spimemio import * # noqa
from .sram import * # noqa
82 changes: 39 additions & 43 deletions amaranth_orchard/memory/hyperram.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,36 +10,33 @@
from amaranth.lib.wiring import In, Out, connect, flipped
from amaranth.utils import ceil_log2

from amaranth.sim import *
from amaranth.sim import Simulator

from amaranth_soc import csr, wishbone
from amaranth_soc.memory import MemoryMap

from chipflow_lib.platforms import BidirPinSignature, OutputPinSignature

__all__ = ["HyperRAMPins", "HyperRAM"]


class HyperRAMPins(wiring.PureInterface):
class Signature(wiring.Signature):
def __init__(self, *, cs_count=1):
self.cs_count = cs_count
super().__init__({
"clk_o": Out(1),
"csn_o": Out(cs_count),
"rstn_o": Out(1),
"rwds_o": Out(1),
"rwds_oe": Out(1),
"rwds_i": In(1),
"dq_o": Out(8),
"dq_oe": Out(8),
"dq_i": In(8),
"clk": Out(OutputPinSignature(1)),
"csn": Out(OutputPinSignature(cs_count)),
"rstn": Out(OutputPinSignature(1)),
"rwds": Out(BidirPinSignature(1)),
"dq": Out(BidirPinSignature(8)),
})

def create(self, *, path=(), src_loc_at=0):
return HyperRAMPins(cs_count=self.cs_count, src_loc_at=1 + src_loc_at)

def __init__(self, *, cs_count=1, path=(), src_loc_at=0):
super().__init__(self.Signature(cs_count=cs_count), path=path, src_loc_at=1 + src_loc_at)
self.cs_count = cs_count


class HyperRAM(wiring.Component):
Expand All @@ -62,14 +59,14 @@ class HRAMConfig(csr.Register, access="w"):
"""
def __init__(self, mem_name=("mem",), *, pins, init_latency=7):
self.pins = pins
self.cs_count = len(self.pins.csn_o)
self.cs_count = pins.cs_count
self.size = 2**23 * self.cs_count # 8MB per CS pin
self.init_latency = init_latency
assert self.init_latency in (6, 7) # TODO: anything else possible ?

regs = csr.Builder(addr_width=3, data_width=8)

self._ctrl_cfg = regs.add("ctrl_cfg", self.CtrlConfig(), offset=0x0)
self._ctrl_cfg = regs.add("ctrl_cfg", self.CtrlConfig(init_latency), offset=0x0)
self._hram_cfg = regs.add("hram_cfg", self.HRAMConfig(), offset=0x4)

self._bridge = csr.Bridge(regs.as_memory_map())
Expand All @@ -79,9 +76,9 @@ def __init__(self, mem_name=("mem",), *, pins, init_latency=7):
data_memory_map.add_resource(name=mem_name, size=self.size, resource=self)

super().__init__({
"ctrl_bus": csr.Signature(addr_width=regs.addr_width, data_width=regs.data_width),
"data_bus": wishbone.Signature(addr_width=ceil_log2(self.size / 4), data_width=32,
granularity=8),
"ctrl_bus": In(csr.Signature(addr_width=regs.addr_width, data_width=regs.data_width)),
"data_bus": In(wishbone.Signature(addr_width=ceil_log2(self.size >> 2), data_width=32,
granularity=8)),
})
self.ctrl_bus.memory_map = ctrl_memory_map
self.data_bus.memory_map = data_memory_map
Expand All @@ -90,7 +87,7 @@ def elaborate(self, platform):
m = Module()
m.submodules.bridge = self._bridge

connect(m, flipped(self.bus), self._bridge.bus)
connect(m, flipped(self.ctrl_bus), self._bridge.bus)

is_ctrl_write = Signal()
latched_adr = Signal(len(self.data_bus.adr))
Expand All @@ -104,7 +101,6 @@ def elaborate(self, platform):

# Data shift register
sr = Signal(48)
sr_shift = Signal()

# Whether or not we need to apply x2 latency
x2_lat = Signal()
Expand All @@ -123,27 +119,27 @@ def elaborate(self, platform):
m.d.sync += counter.eq(counter-1)
with m.If(counter.any()):
# move shift register (sample/output data) on posedge
m.d.sync += sr.eq(Cat(self.pins.dq_i, sr[:-8]))
m.d.sync += sr.eq(Cat(self.pins.dq.i, sr[:-8]))

m.d.comb += [
self.pins.clk_o.eq(clk),
self.pins.csn_o.eq(csn),
self.pins.rstn_o.eq(~ResetSignal()),
self.pins.dq_o.eq(sr[-8:]),
self.pins.clk.o.eq(clk),
self.pins.csn.o.eq(csn),
self.pins.rstn.o.eq(~ResetSignal()),
self.pins.dq.o.eq(sr[-8:]),
self.data_bus.dat_r.eq(sr[:32]),
]

with m.FSM() as fsm:
with m.State("IDLE"):
m.d.sync += [
counter.eq(0),
self.pins.rwds_oe.eq(0),
self.pins.rwds.oe.eq(0),
csn.eq((1 << self.cs_count) - 1), # all disabled
]
with m.If(self.data_bus.stb & self.data_bus.cyc): # data bus activity
m.d.sync += [
csn.eq(~(1 << (self.data_bus.adr[21:]))),
self.pins.dq_oe.eq(1),
self.pins.dq.oe.eq(1),
counter.eq(6),
# Assign CA
sr[47].eq(~self.data_bus.we), # R/W#
Expand All @@ -161,7 +157,7 @@ def elaborate(self, platform):
with m.If(self._hram_cfg.f.val.w_stb): # config register write
m.d.sync += [
csn.eq(~(1 << (self._hram_cfg.f.val.w_data[16:16+ceil_log2(self.cs_count)]))),
self.pins.dq_oe.eq(1),
self.pins.dq.oe.eq(1),
counter.eq(6),
# Assign CA
sr[47].eq(0), # R/W#
Expand All @@ -181,7 +177,7 @@ def elaborate(self, platform):
with m.If(counter == 3):
# RWDS tells us if we need 2x latency or not
# sample at an arbitrary midpoint in CA
m.d.sync += x2_lat.eq(self.pins.rwds_i)
m.d.sync += x2_lat.eq(self.pins.rwds.i)
with m.If(counter == 1):
# (almost) done shifting CA
with m.If(is_ctrl_write):
Expand All @@ -199,31 +195,31 @@ def elaborate(self, platform):
m.d.sync += counter.eq(2*self._ctrl_cfg.f.latency.data - 2)
m.next = "WAIT_LAT"
with m.State("WAIT_LAT"):
m.d.sync += self.pins.dq_oe.eq(0)
m.d.sync += self.pins.dq.oe.eq(0)
with m.If(counter == 1):
# About to shift data
m.d.sync += [
sr[:16].eq(0),
sr[16:].eq(self.data_bus.dat_w),
self.pins.dq_oe.eq(self.data_bus.we),
self.pins.rwds_oe.eq(self.data_bus.we),
self.pins.rwds_o.eq(~self.data_bus.sel[3]),
self.pins.dq.oe.eq(self.data_bus.we),
self.pins.rwds.oe.eq(self.data_bus.we),
self.pins.rwds.o.eq(~self.data_bus.sel[3]),
counter.eq(4),
]
m.next = "SHIFT_DAT"
with m.State("SHIFT_DAT"):
with m.If(counter == 4):
m.d.sync += self.pins.rwds_o.eq(~self.data_bus.sel[2])
m.d.sync += self.pins.rwds.o.eq(~self.data_bus.sel[2])
with m.If(counter == 3):
m.d.sync += self.pins.rwds_o.eq(~self.data_bus.sel[1])
m.d.sync += self.pins.rwds.o.eq(~self.data_bus.sel[1])
with m.If(counter == 2):
m.d.sync += self.pins.rwds_o.eq(~self.data_bus.sel[0])
m.d.sync += self.pins.rwds.o.eq(~self.data_bus.sel[0])
with m.If(counter == 1):
m.next = "ACK_XFER"
with m.State("ACK_XFER"):
m.d.sync += [
self.pins.rwds_oe.eq(0),
self.pins.dq_oe.eq(0),
self.pins.rwds.oe.eq(0),
self.pins.dq.oe.eq(0),
self.data_bus.ack.eq(1),
wait_count.eq(9)
]
Expand All @@ -239,9 +235,9 @@ def elaborate(self, platform):
m.d.sync += [
sr[:16].eq(0),
sr[16:].eq(self.data_bus.dat_w),
self.pins.dq_oe.eq(self.data_bus.we),
self.pins.rwds_oe.eq(self.data_bus.we),
self.pins.rwds_o.eq(~self.data_bus.sel[3]),
self.pins.dq.oe.eq(self.data_bus.we),
self.pins.rwds.oe.eq(self.data_bus.we),
self.pins.rwds.o.eq(~self.data_bus.sel[3]),
latched_adr.eq(self.data_bus.adr),
counter.eq(4),
]
Expand All @@ -258,8 +254,8 @@ def elaborate(self, platform):
m.next = "CTRL_DONE"
with m.State("CTRL_DONE"):
m.d.sync += [
self.pins.rwds_oe.eq(0),
self.pins.dq_oe.eq(0),
self.pins.rwds.oe.eq(0),
self.pins.dq.oe.eq(0),
csn.eq((1 << self.cs_count) - 1),
]
m.next = "IDLE"
Expand Down Expand Up @@ -288,8 +284,8 @@ def process():
yield hram.data_bus.we.eq(0)
yield hram.data_bus.stb.eq(1)
yield hram.data_bus.cyc.eq(1)
yield pins.rwds_i.eq(1)
yield pins.dq_i.eq(0xFF)
yield pins.rwds.i.eq(1)
yield pins.dq.i.eq(0xFF)
for i in range(100):
if (yield hram.data_bus.ack):
yield hram.data_bus.stb.eq(0)
Expand Down
Loading