diff --git a/core/src/main/scala-2/chisel3/OneHotEnumMacros.scala b/core/src/main/scala-2/chisel3/OneHotEnumMacros.scala new file mode 100644 index 00000000000..3320a27f741 --- /dev/null +++ b/core/src/main/scala-2/chisel3/OneHotEnumMacros.scala @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: Apache-2.0 + +package chisel3 + +import scala.reflect.macros.blackbox.Context +import scala.language.experimental.macros + +private[chisel3] trait OneHotEnumIntf extends ChiselEnumIntf { self: OneHotEnum => + override def Value: Type = macro OneHotEnumMacros.ValImpl + override def Value(id: UInt): Type = macro OneHotEnumMacros.ValCustomImpl +} + +private[chisel3] object OneHotEnumMacros { + def ValImpl(c: Context): c.Tree = { + import c.universe._ + + val term = c.internal.enclosingOwner + val name = term.name.decodedName.toString.trim + + if (name.contains(" ")) { + c.abort(c.enclosingPosition, "Value cannot be called without assigning to an enum") + } + + q"""this.do_OHValue($name)""" + } + + def ValCustomImpl(c: Context)(id: c.Expr[UInt]): c.universe.Tree = { + c.abort(c.enclosingPosition, "OneHotEnum does not support custom values") + } +} diff --git a/core/src/main/scala-3/chisel3/ChiselEnumIntf.scala b/core/src/main/scala-3/chisel3/ChiselEnumIntf.scala index c66ff39d987..eb573baa312 100644 --- a/core/src/main/scala-3/chisel3/ChiselEnumIntf.scala +++ b/core/src/main/scala-3/chisel3/ChiselEnumIntf.scala @@ -17,3 +17,7 @@ private[chisel3] trait EnumTypeIntf { self: EnumType => private[chisel3] trait ChiselEnumIntf { self: ChiselEnum => // TODO macros } + +private[chisel3] trait OneHotEnumIntf extends ChiselEnumIntf { self: OneHotEnum => + // TODO macros +} diff --git a/core/src/main/scala/chisel3/ChiselEnum.scala b/core/src/main/scala/chisel3/ChiselEnum.scala index f2cbc684072..535ecba389d 100644 --- a/core/src/main/scala/chisel3/ChiselEnum.scala +++ b/core/src/main/scala/chisel3/ChiselEnum.scala @@ -92,47 +92,11 @@ abstract class EnumType(private[chisel3] val factory: ChiselEnum) extends Elemen if (litOption.isDefined) { true.B } else { - if (factory.isTotal) true.B else factory.all.map(this === _).reduce(_ || _) + if (factory.isTotal) true.B else factory._isValid(this) } } - /** Test if this enumeration is equal to any of the values in a given sequence - * - * @param s a [[scala.collection.Seq$ Seq]] of enumeration values to look for - * @return a hardware [[Bool]] that indicates if this value matches any of the given values - */ - final def isOneOf(s: Seq[EnumType])(implicit sourceInfo: SourceInfo): Bool = { - VecInit(s.map(this === _)).asUInt.orR - } - - /** Test if this enumeration is equal to any of the values given as arguments - * - * @param u1 the first value to look for - * @param u2 zero or more additional values to look for - * @return a hardware [[Bool]] that indicates if this value matches any of the given values - */ - final def isOneOf( - u1: EnumType, - u2: EnumType* - )( - implicit sourceInfo: SourceInfo - ): Bool = isOneOf(u1 +: u2.toSeq) - - def next(implicit sourceInfo: SourceInfo): this.type = { - if (litOption.isDefined) { - val index = factory.all.indexOf(this) - - if (index < factory.all.length - 1) { - factory.all(index + 1).asInstanceOf[this.type] - } else { - factory.all.head.asInstanceOf[this.type] - } - } else { - val enums_with_nexts = factory.all.zip(factory.all.tail :+ factory.all.head) - val next_enum = SeqUtils.priorityMux(enums_with_nexts.map { case (e, n) => (this === e, n) }) - next_enum.asInstanceOf[this.type] - } - } + def next(implicit sourceInfo: SourceInfo): this.type = factory._next(this) private[chisel3] def bindToLiteral(num: BigInt, w: Width): Unit = { val lit = ULit(num, w) @@ -203,7 +167,59 @@ abstract class EnumType(private[chisel3] val factory: ChiselEnum) extends Elemen } abstract class ChiselEnum extends ChiselEnumIntf { - class Type extends EnumType(this) + private[chisel3] protected def _valIs(v: Type, lit: Type)(implicit sourceInfo: SourceInfo): Bool = { + v === lit + } + + private[chisel3] protected def _valIsOneOf(v: Type, s: Seq[Type])(implicit sourceInfo: SourceInfo): Bool = + VecInit(s.map(_valIs(v, _))).asUInt.orR + + private[chisel3] protected def _isValid(v: EnumType)(implicit sourceInfo: SourceInfo): Bool = { + assert(v.isInstanceOf[Type]) + all.map(v === _).reduce(_ || _) + } + + private[chisel3] protected def _next(v: EnumType)(implicit sourceInfo: SourceInfo): v.type = { + assert(v.isInstanceOf[Type]) + + if (v.litOption.isDefined) { + val index = v.factory.all.indexOf(v) + + if (index < v.factory.all.length - 1) { + v.factory.all(index + 1).asInstanceOf[v.type] + } else { + v.factory.all.head.asInstanceOf[v.type] + } + } else { + val enums_with_nexts = v.factory.all.zip(v.factory.all.tail :+ v.factory.all.head) + val next_enum = SeqUtils.priorityMux(enums_with_nexts.map { case (e, n) => (v === e, n) }) + next_enum.asInstanceOf[v.type] + } + } + + class Type extends EnumType(this) { + + /** Test if this enumeration is equal to any of the values in a given sequence + * + * @param s a [[scala.collection.Seq$ Seq]] of enumeration values to look for + * @return a hardware [[Bool]] that indicates if this value matches any of the given values + */ + final def isOneOf(s: Seq[Type])(implicit sourceInfo: SourceInfo): Bool = _valIsOneOf(this, s) + + /** Test if this enumeration is equal to any of the values given as arguments + * + * @param u1 the first value to look for + * @param u2 zero or more additional values to look for + * @return a hardware [[Bool]] that indicates if this value matches any of the given values + */ + final def isOneOf( + u1: Type, + u2: Type* + )( + implicit sourceInfo: SourceInfo + ): Bool = isOneOf(u1 +: u2.toSeq) + } + object Type { def apply(): Type = ChiselEnum.this.apply() } diff --git a/core/src/main/scala/chisel3/OneHotEnum.scala b/core/src/main/scala/chisel3/OneHotEnum.scala new file mode 100644 index 00000000000..e57d9a8cd7c --- /dev/null +++ b/core/src/main/scala/chisel3/OneHotEnum.scala @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: Apache-2.0 + +package chisel3 + +import chisel3.experimental.SourceInfo + +abstract class OneHotEnum extends ChiselEnum with OneHotEnumIntf { + private var next1Pos = 0 + + // copied from chisel3.util + private def isPow2(in: BigInt): Boolean = in > 0 && ((in & (in - 1)) == 0) + private def log2Ceil(in: BigInt): Int = (in - 1).bitLength + + override def _valIs(v: Type, lit: Type)(implicit sourceInfo: SourceInfo): Bool = { + require(lit.isLit, "Can only compare against literal values") + val yy = lit.litValue + require(isPow2(yy), s"Can only compare against one-hot values, got $yy (0b${yy.toString(2)})") + v.asUInt.apply(log2Ceil(yy)) + } + + override def _isValid(v: EnumType)(implicit sourceInfo: SourceInfo): Bool = { + assert(v.isInstanceOf[Type]) + assert(v.getWidth == all.length, s"OneHotEnum ${this} has ${all.length} values, but ${v} has width ${v.getWidth}") + val x = v.asUInt + x.orR && ((x & (x - 1.U)) === 0.U) + } + + override def _next(v: EnumType)(implicit sourceInfo: SourceInfo): v.type = { + assert(v.isInstanceOf[Type]) + + if (v.litOption.isDefined) { + val index = v.factory.all.indexOf(v) + + if (index < v.factory.all.length - 1) { + v.factory.all(index + 1).asInstanceOf[v.type] + } else { + v.factory.all.head.asInstanceOf[v.type] + } + } else { + safe(v.asUInt.rotateLeft(1))._1.asInstanceOf[v.type] + } + } + + override def isTotal: Boolean = false + + // TODO: Is there a cleaner way? + final implicit class OneHotType(value: Type) extends Type { + override def isLit: Boolean = value.isLit + + override def litValue: BigInt = value.litValue + + final def is(other: Type)(implicit sourceInfo: SourceInfo): Bool = _valIs(value, other) + + /** + * Multiplexer that selects between multiple values based on this one-hot enum. + * + * @param choices a sequence of tuples of (enum value, output when matched) + * @return the output corresponding to the matched enum value + * @note the output is undefined if none of the values match + */ + final def select[T <: Data](choices: Seq[(Type, T)])(implicit sourceInfo: SourceInfo): T = { + require(choices.nonEmpty, "select must be passed a non-empty list of choices") + // FIXME: this is a workaround to suppress a superfluous cast warning emitted by [[SeqUtils.oneHotMux]] when T is of the same EnumType. Unfortunately, it also hides potential cast warnings from the inner expressions. + suppressEnumCastWarning { + SeqUtils.oneHotMux(choices.map { case (oh, t) => is(oh) -> t }) + } + } + + /** + * Multiplexer that selects between multiple values based on this one-hot enum. + * + * @param firstChoice a tuple of (enum value, output when matched) + * @param otherChoices a varargs list of tuples of (enum value, output when matched) + * @return the output corresponding to the matched enum value + * @note if none of the enum values match, the output is undefined + */ + final def select[T <: Data]( + firstChoice: (Type, T), + otherChoices: (Type, T)* + )(implicit sourceInfo: SourceInfo): T = select(firstChoice +: otherChoices) + + } + + def do_OHValue(name: String): Type = { + val value = super.do_Value(name, BigInt(2).pow(next1Pos).U) + next1Pos += 1 + value + } +} diff --git a/src/test/scala-2/chiselTests/OneHotEnumSpec.scala b/src/test/scala-2/chiselTests/OneHotEnumSpec.scala new file mode 100644 index 00000000000..96056585a76 --- /dev/null +++ b/src/test/scala-2/chiselTests/OneHotEnumSpec.scala @@ -0,0 +1,372 @@ +// SPDX-License-Identifier: Apache-2.0 + +package chiselTests + +import chisel3._ +import chisel3.util.{is, switch, Counter} +import chisel3.simulator.scalatest.ChiselSim +import chisel3.simulator.stimulus.RunUntilFinished +import chisel3.testing.scalatest.FileCheck +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import chisel3.util.Decoupled +import chisel3.util.isPow2 + +object OneHotEnumExample extends OneHotEnum { + val A, B, C, D, E = Value +} + +object OtherOneHotEnum extends OneHotEnum { + val W, X, Y, Z = Value +} + +class OneHotEnumSafeCast extends Module { + val io = IO(new Bundle { + val in = Input(UInt(OneHotEnumExample.getWidth.W)) + val out = Output(OneHotEnumExample()) + val valid = Output(Bool()) + }) + + val (enumVal, valid) = OneHotEnumExample.safe(io.in) + io.out :#= enumVal + io.valid := valid +} + +class OneHotEnumSafeCastTester extends Module { + for ((enumVal, i) <- OneHotEnumExample.all.zipWithIndex) { + val lit = (1 << i).U(OneHotEnumExample.getWidth.W) + val mod = Module(new OneHotEnumSafeCast) + mod.io.in :#= lit + assert(mod.io.out === enumVal) + assert(mod.io.valid === true.B) + } + + val invalid_values = + (1 until (1 << OneHotEnumExample.getWidth)).filter(!isPow2(_)).map(_.U) + + for (invalid_val <- invalid_values) { + val mod = Module(new SafeCastFromNonLit) + mod.io.in := invalid_val + + assert(mod.io.valid === false.B) + } + + stop() +} + +class OneHotEnumFSM extends Module { + + object State extends OneHotEnum { + val Idle, One, Two, Three = Value + } + + object State2 extends ChiselEnum { + val Idle, One, Two, Three = Value + } + + val io = IO(new Bundle { + val in = Input(UInt(8.W)) + val out = Output(UInt(8.W)) + val out2 = Output(UInt(8.W)) + val other_out = Output(UInt(5.W)) + val state = Output(UInt(4.W)) + }) + + import State._ + + val state = RegInit(Idle) + val state2 = RegInit(State2.Idle) + + assert(state.getWidth == State.all.length) + + assert(state.isValid) + + assert((1.U << state2.asUInt) === state.asUInt) + + io.state :#= state.asUInt + io.out2 := DontCare + + printf(cf"state is $state (${state.asUInt}%b)\n") + + when(state2 === State2.Idle) { + assert(state.is(State.Idle)) + assert(state.next.is(State.One)) + assert(state.next === State.One) + assert(state2.next === State2.One) + io.out2 := 0.U + state2 :#= State2.One + } + + when(state2 === State2.One) { + io.out2 := 1.U + state2 :#= State2.Two + + assert(state.next.is(State.Two)) + assert(state2.next === State2.Two) + } + + when(state2.isOneOf(State2.Two)) { + io.out2 := 2.U + state2 :#= State2.Three + assert(state.next.is(State.Three)) + assert(state2.next === State2.Three) + } + + when(state2.isOneOf(State2.Three)) { + io.out2 :#= io.in + state2 :#= State2.Idle + assert(state.next.is(State.Idle)) + assert(state2.next === State2.Idle) + } + + assert(state.isOneOf(State.all)) + + for (s <- State.all) { + assert(state.is(s) === (state === s)) + assert(state.is(s) === state.isOneOf(s)) + } + + assert(State.Idle.next === State.One) + assert(State.One.next === State.Two) + assert(State.Two.next === State.Three) + assert(State.Three.next === State.Idle) + + state :#= state.select( + Idle -> One, + One -> Two, + Two -> Three, + Three -> Idle + ) + + io.out :#= state.select( + Idle -> 0x00.U, + One -> 0x01.U, + Two -> 0x02.U, + Three -> io.in + ) + + assert(io.out === io.out2) + + io.other_out :#= state.select( + Idle -> 0x10.U, + One -> 0x11.U, + Two -> 0x12.U, + Three -> 0x13.U + ) +} + +class OneHotEnumFSMTester extends Module { + val mod = Module(new OneHotEnumFSM) + val counter = Counter(9) + + val expectedStateIndex = counter.value % mod.State.all.length.U + + mod.io.in := counter.value + + assert(mod.io.state === (1.U << expectedStateIndex)) + assert(mod.io.other_out === (0x10.U + expectedStateIndex)) + + switch(expectedStateIndex) { + is(0.U) { + assert(mod.io.out2 === 0.U) + } + is(1.U) { + assert(mod.io.out2 === 1.U) + } + is(2.U) { + assert(mod.io.out2 === 2.U) + } + is(3.U) { + assert(mod.io.out2 === counter.value) + } + } + + when(counter.inc()) { + stop() + } +} + +class OneHotEnumSequenceDetector extends Module { + + object State extends OneHotEnum { + val Idle, Saw1, Saw10, Saw101 = Value + } + + val io = IO(new Bundle { + val in = Input(Bool()) + val detect = Output(Bool()) + val state = Output(UInt(State.all.length.W)) + }) + + import State._ + + val state = RegInit(Idle) + + assert(state.getWidth == State.all.length) + assert(state.isValid) + + assert(state.isOneOf(State.all)) + + for (s <- State.all) { + assert(state.is(s) === (state === s)) + assert(state.is(s) === state.isOneOf(s)) + } + + assert(State.Idle.next === State.Saw1) + assert(State.Saw1.next === State.Saw10) + assert(State.Saw10.next === State.Saw101) + assert(State.Saw101.next === State.Idle) + + io.detect := state.is(Saw101) + + state :#= state.select( + Idle -> Mux(io.in, Saw1, Idle), + Saw1 -> Mux(io.in, Saw1, Saw10), + Saw10 -> Mux(io.in, Saw101, Idle), + Saw101 -> Mux(io.in, Saw1, Saw10) + ) + + io.state :#= state.asUInt + + printf(cf"state is $state (${state.asUInt}%b), in is ${io.in}, detect is ${io.detect}\n") +} + +class OneHotEnumSequenceDetectorTester extends Module { + val mod = Module(new OneHotEnumSequenceDetector) + + import mod.State._ + + val symbols = VecInit(Seq(1, 0, 1, 0, 1, 1, 0, 1, 0, 0).map(_.B)) + val expectedHits = VecInit(Seq(0, 0, 0, 1, 0, 1, 0, 0, 1, 0).map(_.B)) + + val expectedStates = VecInit( + Seq(Idle, Saw1, Saw10, Saw101, Saw10, Saw101, Saw1, Saw10, Saw101, Saw10, Idle) + ) + + val counter = Counter(symbols.length) + + mod.io.in := symbols(counter.value) + + assert( + mod.io.detect === expectedHits(counter.value), + cf"mismatch at ${counter.value}: got ${mod.io.detect}, expected ${expectedHits(counter.value)}" + ) + assert(mod.io.state === expectedStates(counter.value).asUInt) + + when(counter.inc()) { + stop() + } +} + +object VendingMachineState extends OneHotEnum { + val Idle, Have5, Have10, Vend = Value +} + +class OneHotEnumVendingMachine extends Module { + + val io = IO(new Bundle { + val coin = Flipped(Decoupled(UInt(2.W))) // 0: none, 1: nickel, 2: dime + val vend = Output(Bool()) + val change = Output(Bool()) + val state = Output(VendingMachineState()) + }) + + import VendingMachineState._ + + val state = RegInit(Idle) + + assert(state.getWidth == VendingMachineState.all.length) + assert(state.isValid) + assert(state.isOneOf(VendingMachineState.all)) + for (s <- VendingMachineState.all) { + assert(state.is(s) === (state === s)) + assert(state.is(s) === state.isOneOf(s)) + } + + val change = RegInit(false.B) + + val insertFive = io.coin.valid && io.coin.bits === 1.U + val insertTen = io.coin.valid && io.coin.bits === 2.U + + io.coin.ready := state.isOneOf(Idle, Have5, Have10) + + val nextState = state.select( + Idle -> Mux(insertTen, Have10, Mux(insertFive, Have5, Idle)), + Have5 -> Mux(insertTen, Vend, Mux(insertFive, Have10, Have5)), + Have10 -> Mux(insertFive || insertTen, Vend, Have10), + Vend -> Idle + ) + + state :#= nextState + change := state.is(Have10) && insertTen + + io.state :#= state + io.vend := state.is(Vend) + io.change := change +} + +class OneHotEnumVendingMachineTester extends Module { + val mod = Module(new OneHotEnumVendingMachine) + + import VendingMachineState._ + + val coins = VecInit(Seq(0, 1, 0, 2, 0, 0, 2, 1, 0, 2, 2, 0).map(_.U(2.W))) + val expectedStates = VecInit( + Seq(Idle, Idle, Have5, Have5, Vend, Idle, Idle, Have10, Vend, Idle, Have10, Vend) + ) + val expectedVend = VecInit(Seq(0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1).map(_.B)) + val expectedChange = VecInit(Seq(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1).map(_.B)) + + val counter = Counter(coins.length) + + mod.io.coin.bits :#= coins(counter.value) + mod.io.coin.valid := true.B + + assert(mod.io.state.isOneOf(Idle, Have5, Have10) === mod.io.coin.ready) + + assert(mod.io.state === expectedStates(counter.value)) + assert(mod.io.vend === expectedVend(counter.value)) + assert(mod.io.change === expectedChange(counter.value)) + + when(counter.inc()) { + stop() + } +} + +class OneHotEnumSpec extends AnyFlatSpec with Matchers with LogUtils with ChiselSim with FileCheck { + behavior of "OneHotEnum" + + it should "maintain Scala-level type-safety" in { + def foo(e: OneHotEnumExample.Type): Unit = {} + + "foo(OneHotEnumExample.A); foo(OneHotEnumExample.A.next); foo(OneHotEnumExample.E.next)" should compile + "foo(OtherOneHotEnum.otherEnum)" shouldNot compile + "foo(EnumExample.otherEnum)" shouldNot compile + "foo(OtherEnum.otherEnum)" shouldNot compile + } + + it should "prevent enums from being declared without names" in { + "object UnnamedEnum1 extends OneHotEnum { Value }" shouldNot compile + } + + it should "prevent enums from being declared with custom values" in { + "object UnnamedEnum2 extends OneHotEnum { A = Value(1.U) }" shouldNot compile + } + + it should "safely cast non-literal UInts to enums correctly and detect illegal casts" in { + simulate(new OneHotEnumSafeCastTester)(RunUntilFinished(3)) + } + + "OneHotEnumFSM" should "work" in { + simulate(new OneHotEnumFSMTester)(RunUntilFinished(10)) + } + + "OneHotEnumSequenceDetector" should "work" in { + simulate(new OneHotEnumSequenceDetectorTester)(RunUntilFinished(12)) + } + + "OneHotEnumVendingMachine" should "work" in { + simulate(new OneHotEnumVendingMachineTester)(RunUntilFinished(16)) + } + +}