Skip to content

Commit ada1bcc

Browse files
committed
Harden Period creation and use of Run/PhaseId
1 parent 122fd6a commit ada1bcc

File tree

11 files changed

+89
-80
lines changed

11 files changed

+89
-80
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,11 +153,11 @@ class Compiler {
153153
List(new GenBCode) :: // Generate JVM bytecode
154154
Nil
155155

156-
// TODO: Initially 0, so that the first nextRunId call would return InitialRunId == 1
157-
// Changing the initial runId from 1 to 0 makes the scala2-library-bootstrap fail to compile,
156+
// TODO: Initially InitialRunId - 1, so that the first nextRunId call would return InitialRunId
157+
// Setting the initial runId to InitialRunId - 1 makes the scala2-library-bootstrap fail to compile,
158158
// when the underlying issue is fixed, please update dotc.profiler.RealProfiler.chromeTrace logic
159-
private var runId: Int = 1
160-
def nextRunId: Int = {
159+
private var runId: Periods.RunId = Periods.InitialRunId
160+
def nextRunId: Periods.RunId = {
161161
runId += 1; runId
162162
}
163163

compiler/src/dotty/tools/dotc/Run.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -579,11 +579,13 @@ extends ImplicitRunInfo, ConstraintRunInfo, cc.CaptureRunInfo {
579579
start.setRun(this: @unchecked)
580580
}
581581

582-
private var myCtx: Context | Null = rootContext(using ictx)
582+
private var myCtx: Context | Null = null
583583

584584
/** The context created for this run */
585-
given runContext[Dummy_so_its_a_def]: Context = myCtx.nn
586-
assert(runContext.runId <= Periods.MaxPossibleRunId)
585+
given runContext[Dummy_so_its_a_def]: Context =
586+
if myCtx eq null then myCtx = rootContext(using ictx)
587+
assert(myCtx.nn.runId <= Periods.MaxPossibleRunId)
588+
myCtx.nn
587589
}
588590

589591
object Run {

compiler/src/dotty/tools/dotc/cc/Capability.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import util.{SimpleIdentitySet, EqHashMap}
88
import util.common.alwaysTrue
99
import scala.collection.mutable
1010
import CCState.*
11-
import Periods.{NoRunId, RunWidth}
11+
import Periods.{NoRunId, RunId, RunWidth}
1212
import compiletime.uninitialized
1313
import StdNames.nme
1414
import CaptureSet.{Refs, emptyRefs, VarState}
@@ -48,7 +48,7 @@ import collection.immutable
4848
*/
4949
object Capabilities:
5050
opaque type Validity = Int
51-
def validId(runId: Int, iterId: Int): Validity =
51+
def validId(runId: RunId, iterId: Int): Validity =
5252
runId + (iterId << RunWidth)
5353
def currentId(using Context): Validity = validId(ctx.runId, ccState.iterationId)
5454
val invalid: Validity = validId(NoRunId, 0)

compiler/src/dotty/tools/dotc/core/Contexts.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1055,7 +1055,7 @@ object Contexts {
10551055
private[core] var fusedPhases: Array[Phase] = Array.empty[Phase]
10561056

10571057
/** Next denotation transformer id */
1058-
private[core] var nextDenotTransformerId: Array[Int] = uninitialized
1058+
private[core] var nextDenotTransformerId: Array[Periods.PhaseId] = uninitialized
10591059

10601060
private[core] var denotTransformers: Array[DenotTransformer] = uninitialized
10611061

compiler/src/dotty/tools/dotc/core/Denotations.scala

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,7 @@ object Denotations {
681681
* There may be several `SingleDenotation`s with different validity
682682
* representing the same underlying definition at different phases.
683683
* These are called a "flock". Flock members are generated by
684-
* @See current. Flock members are connected in a ring
684+
* `current`. Flock members are connected in a ring
685685
* with their `nextInRun` fields.
686686
*
687687
* There are the following invariants concerning flock members
@@ -696,10 +696,10 @@ object Denotations {
696696
* of this run.
697697
*/
698698
def initial: SingleDenotation =
699-
if (validFor.firstPhaseId <= 1) this
699+
if (validFor.firstPhaseId == FirstPhaseId) this
700700
else {
701701
var current = nextInRun
702-
while (current.validFor.code > this.validFor.code) current = current.nextInRun
702+
while (current.validFor > this.validFor) current = current.nextInRun
703703
current
704704
}
705705

@@ -784,7 +784,7 @@ object Denotations {
784784
* are otherwise undefined.
785785
*/
786786
def skipRemoved(using Context): SingleDenotation =
787-
if (validFor.code <= 0) nextDefined else this
787+
if (validFor == Nowhere) nextDefined else this
788788

789789
/** Produce a denotation that is valid for the given context.
790790
* Usually called when !(validFor contains ctx.period)
@@ -816,10 +816,10 @@ object Denotations {
816816
var cur = this
817817
// search for containing period as long as nextInRun increases.
818818
var next = nextInRun
819-
while next.validFor.code > valid.code && !next.validFor.contains(currentPeriod) do
819+
while next.validFor > valid && !next.validFor.contains(currentPeriod) do
820820
cur = next
821821
next = next.nextInRun
822-
if next.validFor.code > valid.code then
822+
if next.validFor > valid then
823823
// in this case, next.validFor contains currentPeriod
824824
cur = next
825825
cur
@@ -875,14 +875,14 @@ object Denotations {
875875
cur
876876
end goBack
877877

878-
if valid.code <= 0 then
878+
if valid == Nowhere then
879879
// can happen if we sit on a stale denotation which has been replaced
880880
// wholesale by an installAfter; in this case, proceed to the next
881881
// denotation and try again.
882882
nextDefined
883883
else if valid.runId != currentPeriod.runId then
884884
toNewRun
885-
else if currentPeriod.code > valid.code then
885+
else if currentPeriod > valid then
886886
goForward
887887
else
888888
goBack
@@ -923,7 +923,7 @@ object Denotations {
923923
*/
924924
protected def transformAfter(phase: DenotTransformer, f: SymDenotation => SymDenotation)(using Context): Unit = {
925925
var current = symbol.current
926-
while (current.validFor.firstPhaseId < phase.id && (current.nextInRun.validFor.code > current.validFor.code))
926+
while (current.validFor.firstPhaseId < phase.id && (current.nextInRun.validFor > current.validFor))
927927
current = current.nextInRun
928928
var hasNext = true
929929
while ((current.validFor.firstPhaseId >= phase.id) && hasNext) {
@@ -932,7 +932,7 @@ object Denotations {
932932
current1.validFor = current.validFor
933933
current.replaceWith(current1)
934934
}
935-
hasNext = current1.nextInRun.validFor.code > current1.validFor.code
935+
hasNext = current1.nextInRun.validFor > current1.validFor
936936
current = current1.nextInRun
937937
}
938938
}

compiler/src/dotty/tools/dotc/core/Periods.scala

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ object Periods {
2424
/** Are all base types in the current period guaranteed to be the same as in period `p`? */
2525
def currentHasSameBaseTypesAs(p: Period)(using Context): Boolean =
2626
val period = ctx.period
27-
period.code == p.code ||
27+
period == p ||
2828
period.runId == p.runId &&
2929
unfusedPhases(period.phaseId).sameBaseTypesStartId ==
3030
unfusedPhases(p.phaseId).sameBaseTypesStartId
@@ -37,9 +37,8 @@ object Periods {
3737
* last phase id: 7 bits
3838
* #phases before last: 7 bits
3939
*
40-
* // Dmitry: sign == 0 isn't actually always true, in some cases phaseId == -1 is used for shifts, that easily creates code < 0
4140
*/
42-
class Period(val code: Int) extends AnyVal with Showable {
41+
class Period private[Periods] (private val code: Int) extends AnyVal with Showable {
4342

4443
/** The run identifier of this period. */
4544
def runId: RunId = code >>> (PhaseWidth * 2)
@@ -52,7 +51,7 @@ object Periods {
5251
(code >>> PhaseWidth) & PhaseMask
5352

5453
/** The first phase of this period */
55-
def firstPhaseId: Int = lastPhaseId - (code & PhaseMask)
54+
def firstPhaseId: PhaseId = lastPhaseId - (code & PhaseMask)
5655

5756
def containsPhaseId(id: PhaseId): Boolean = firstPhaseId <= id && id <= lastPhaseId
5857

@@ -76,7 +75,7 @@ object Periods {
7675
// iff r1 == r2 & l1 >= l2 && l1 - d1 <= l2 - d2
7776
// q.e.d
7877
val lastDiff = (code - that.code) >>> PhaseWidth
79-
lastDiff + (that.code & PhaseMask ) <= (this.code & PhaseMask)
78+
lastDiff + (that.code & PhaseMask) <= (this.code & PhaseMask)
8079
}
8180

8281
/** Does this period overlap with given period? */
@@ -101,25 +100,31 @@ object Periods {
101100
this.firstPhaseId min that.firstPhaseId,
102101
this.lastPhaseId max that.lastPhaseId)
103102

103+
inline def <(that: Period): Boolean =
104+
this.code < that.code
105+
106+
inline def >(that: Period): Boolean =
107+
this.code > that.code
108+
104109
def toText(p: Printer): Text =
105110
inContext(p.printerContext):
106111
this match
107-
case Nowhere => "Nowhere"
108-
case InitialPeriod => "InitialPeriod"
109-
case InvalidPeriod => "InvalidPeriod"
110-
case Period(NoRunId, 0, PhaseMask) => s"Period(NoRunId.all)"
111-
case Period(runId, 0, PhaseMask) => s"Period($runId.all)"
112-
case Period(runId, p1, pn) if p1 == pn => s"Period($runId.$p1(${ctx.base.phases(p1)}))"
113-
case Period(runId, p1, pn) => s"Period($runId.$p1(${ctx.base.phases(p1)})-$pn(${ctx.base.phases(pn)}))"
112+
case Nowhere => "Nowhere"
113+
case InitialPeriod => "InitialPeriod"
114+
case InvalidPeriod => "InvalidPeriod"
115+
case Period(NoRunId, FirstPhaseId, MaxPossiblePhaseId) => s"Period(NoRunId.all)"
116+
case Period(runId, FirstPhaseId, MaxPossiblePhaseId) => s"Period($runId.all)"
117+
case Period(runId, p1, pn) if p1 == pn => s"Period($runId.$p1(${ctx.base.phases(p1)}))"
118+
case Period(runId, p1, pn) => s"Period($runId.$p1(${ctx.base.phases(p1)})-$pn(${ctx.base.phases(pn)}))"
114119

115120
override def toString: String = this match
116-
case Nowhere => "Nowhere"
117-
case InitialPeriod => "InitialPeriod"
118-
case InvalidPeriod => "InvalidPeriod"
119-
case Period(NoRunId, 0, PhaseMask) => s"Period(NoRunId.all)"
120-
case Period(runId, 0, PhaseMask) => s"Period($runId.all)"
121-
case Period(runId, p1, pn) if p1 == pn => s"Period($runId.$p1)"
122-
case Period(runId, p1, pn) => s"Period($runId.$p1-$pn)"
121+
case Nowhere => "Nowhere"
122+
case InitialPeriod => "InitialPeriod"
123+
case InvalidPeriod => "InvalidPeriod"
124+
case Period(NoRunId, FirstPhaseId, MaxPossiblePhaseId) => s"Period(NoRunId.all)"
125+
case Period(runId, FirstPhaseId, MaxPossiblePhaseId) => s"Period($runId.all)"
126+
case Period(runId, p1, pn) if p1 == pn => s"Period($runId.$p1)"
127+
case Period(runId, p1, pn) => s"Period($runId.$p1-$pn)"
123128

124129
def ==(that: Period): Boolean = this.code == that.code
125130
def !=(that: Period): Boolean = this.code != that.code
@@ -137,7 +142,7 @@ object Periods {
137142

138143
/** The interval consisting of all periods of given run id */
139144
def allInRun(rid: RunId): Period =
140-
apply(rid, 0, PhaseMask)
145+
apply(rid, FirstPhaseId, MaxPossiblePhaseId)
141146

142147
def unapply(p: Period): Extractor = new Extractor(p.code)
143148

@@ -151,13 +156,6 @@ object Periods {
151156
}
152157
}
153158

154-
inline val NowhereCode = 0
155-
final val Nowhere: Period = new Period(NowhereCode)
156-
157-
final val InitialPeriod: Period = Period(InitialRunId, FirstPhaseId)
158-
159-
final val InvalidPeriod: Period = Period(NoRunId, NoPhaseId)
160-
161159
/** An ordinal number for compiler runs. First run has number 1. */
162160
type RunId = Int
163161
inline val NoRunId = 0
@@ -172,6 +170,11 @@ object Periods {
172170

173171
/** The number of bits needed to encode a phase identifier. */
174172
inline val PhaseWidth = 7
175-
inline val PhaseMask = (1 << PhaseWidth) - 1
173+
private inline val PhaseMask = (1 << PhaseWidth) - 1
176174
inline val MaxPossiblePhaseId = PhaseMask
175+
176+
private inline val NowhereCode = 0
177+
final val Nowhere: Period = new Period(NowhereCode)
178+
final val InitialPeriod: Period = Period(InitialRunId, FirstPhaseId)
179+
final val InvalidPeriod: Period = Period(NoRunId, NoPhaseId)
177180
}

compiler/src/dotty/tools/dotc/core/Phases.scala

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import config.Printers.config
1212
import scala.collection.mutable.ListBuffer
1313
import dotty.tools.dotc.transform.MegaPhase.*
1414
import dotty.tools.dotc.transform.*
15-
import Periods.*
1615
import parsing.Parser
1716
import printing.XprintMode
1817
import typer.{TyperPhase, RefChecks}
@@ -64,7 +63,7 @@ object Phases {
6463
protected def run(using Context): Unit = unsupported("run")
6564
def transform(ref: SingleDenotation)(using Context): SingleDenotation =
6665
unsupported("transform")
67-
override def lastPhaseId(using Context): Int = id
66+
override def lastPhaseId(using Context): PhaseId = id
6867
}
6968

7069
final def phasePlan: List[List[Phase]] = this.phasesPlan
@@ -153,7 +152,7 @@ object Phases {
153152
nextDenotTransformerId = new Array[Int](phases.length)
154153
denotTransformers = new Array[DenotTransformer](phases.length)
155154

156-
var phaseId = 0
155+
var phaseId: PhaseId = 0
157156
def nextPhaseId = {
158157
phaseId += 1
159158
phaseId // starting from 1 as NoPhase is 0
@@ -260,7 +259,7 @@ object Phases {
260259
private var myGenBCodePhase: Phase = uninitialized
261260
private var myCheckCapturesPhase: Phase = uninitialized
262261

263-
private var myCheckCapturesPhaseId: Int = -2
262+
private var myCheckCapturesPhaseId: PhaseId = -2
264263
// -1 means undefined, 0 means NoPhase, we make sure that we don't get a false hit
265264
// if ctx.phaseId is either of these.
266265

@@ -292,7 +291,7 @@ object Phases {
292291
final def flattenPhase: Phase = myFlattenPhase
293292
final def genBCodePhase: Phase = myGenBCodePhase
294293
final def checkCapturesPhase: Phase = myCheckCapturesPhase
295-
final def checkCapturesPhaseId: Int = myCheckCapturesPhaseId
294+
final def checkCapturesPhaseId: PhaseId = myCheckCapturesPhaseId
296295

297296
private def setSpecificPhases() = {
298297
def phaseOfClass(pclass: Class[?]) = phases.find(pclass.isInstance).getOrElse(NoPhase)
@@ -461,7 +460,7 @@ object Phases {
461460
*/
462461
def printingContext(ctx: Context): Context = ctx
463462

464-
private var myPeriod: Period = Periods.InvalidPeriod
463+
private var myPeriod: Period = InvalidPeriod
465464
private var myBase: ContextBase = uninitialized
466465
private var myErasedTypes = false
467466
private var myFlatClasses = false
@@ -477,29 +476,29 @@ object Phases {
477476
* is reserved for NoPhase and the first real phase is at position 1.
478477
* -1 if the phase is not installed in the context.
479478
*/
480-
def id: Int = myPeriod.firstPhaseId
479+
def id: PhaseId = myPeriod.firstPhaseId
481480

482481
def period: Period = myPeriod
483-
def start: Int = myPeriod.firstPhaseId
484-
def end: Periods.PhaseId = myPeriod.lastPhaseId
482+
def start: PhaseId = myPeriod.firstPhaseId
483+
def end: PhaseId = myPeriod.lastPhaseId
485484

486485
final def erasedTypes: Boolean = myErasedTypes // Phase is after erasure
487486
final def flatClasses: Boolean = myFlatClasses // Phase is after flatten
488487
final def refChecked: Boolean = myRefChecked // Phase is after RefChecks
489488
final def lambdaLifted: Boolean = myLambdaLifted // Phase is after LambdaLift
490489
final def patternTranslated: Boolean = myPatternTranslated // Phase is after PatternMatcher
491490

492-
final def sameMembersStartId: Int = mySameMembersStartId
491+
final def sameMembersStartId: PhaseId = mySameMembersStartId
493492
// id of first phase where all symbols are guaranteed to have the same members as in this phase
494-
final def sameParentsStartId: Int = mySameParentsStartId
493+
final def sameParentsStartId: PhaseId = mySameParentsStartId
495494
// id of first phase where all symbols are guaranteed to have the same parents as in this phase
496-
final def sameBaseTypesStartId: Int = mySameBaseTypesStartId
495+
final def sameBaseTypesStartId: PhaseId = mySameBaseTypesStartId
497496
// id of first phase where all symbols are guaranteed to have the same base tpyes as in this phase
498497

499-
protected[Phases] def init(base: ContextBase, start: Int, end: Int): Unit = {
498+
protected[Phases] def init(base: ContextBase, start: PhaseId, end: PhaseId): Unit = {
500499
if (start >= FirstPhaseId)
501-
assert(myPeriod == Periods.InvalidPeriod, s"phase $this has already been used once; cannot be reused")
502-
assert(start <= Periods.MaxPossiblePhaseId, s"Too many phases, Period bits overflow")
500+
assert(myPeriod == InvalidPeriod, s"phase $this has already been used once; cannot be reused")
501+
assert(start <= MaxPossiblePhaseId, s"Too many phases, Period bits overflow")
503502
myBase = base
504503
myPeriod = Period(NoRunId, start, end)
505504
myErasedTypes = prev.getClass == classOf[Erasure] || prev.erasedTypes
@@ -512,7 +511,7 @@ object Phases {
512511
mySameBaseTypesStartId = if (changesBaseTypes) id else prev.sameBaseTypesStartId
513512
}
514513

515-
protected[Phases] def init(base: ContextBase, id: Int): Unit = init(base, id, id)
514+
protected[Phases] def init(base: ContextBase, id: PhaseId): Unit = init(base, id, id)
516515

517516
final def <=(that: Phase): Boolean =
518517
exists && id <= that.id
@@ -585,7 +584,7 @@ object Phases {
585584
def flattenPhase(using Context): Phase = ctx.base.flattenPhase
586585
def genBCodePhase(using Context): Phase = ctx.base.genBCodePhase
587586
def checkCapturesPhase(using Context): Phase = ctx.base.checkCapturesPhase
588-
def checkCapturesPhaseId(using Context): Int = ctx.base.checkCapturesPhaseId
587+
def checkCapturesPhaseId(using Context): PhaseId = ctx.base.checkCapturesPhaseId
589588

590589
def unfusedPhases(using Context): Array[Phase] = ctx.base.phases
591590

compiler/src/dotty/tools/dotc/core/SymDenotations.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2123,7 +2123,7 @@ object SymDenotations {
21232123
if (proceedWithEnter(sym, mscope)) {
21242124
enterNoReplace(sym, mscope)
21252125
val nxt = this.nextInRun
2126-
if (nxt.validFor.code > this.validFor.code)
2126+
if (nxt.validFor > this.validFor)
21272127
this.nextInRun.asSymDenotation.asClass.enter(sym)
21282128
}
21292129
}
@@ -2965,7 +2965,7 @@ object SymDenotations {
29652965
}
29662966

29672967
def isValidAt(phase: Phase)(using Context) =
2968-
checkedPeriod.code == ctx.period.code ||
2968+
checkedPeriod == ctx.period ||
29692969
createdAt.runId == ctx.runId &&
29702970
createdAt.phaseId < unfusedPhases.length &&
29712971
sameGroup(unfusedPhases(createdAt.phaseId), phase) &&

compiler/src/dotty/tools/dotc/core/Symbols.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ object Symbols extends SymUtils {
105105
/** The current denotation of this symbol */
106106
final def denot(using Context): SymDenotation = {
107107
util.Stats.record("Symbol.denot")
108-
if checkedPeriod.code == ctx.period.code then lastDenot
108+
if checkedPeriod == ctx.period then lastDenot
109109
else computeDenot(lastDenot)
110110
}
111111

0 commit comments

Comments
 (0)