Skip to content

Commit a8e824f

Browse files
committed
StateV2: Add primitives for working with the current time
Source-Commit: e85e67f297a87792136c8d2bd896e12122f31c95
1 parent 589e69c commit a8e824f

File tree

4 files changed

+478
-0
lines changed

4 files changed

+478
-0
lines changed

unstable/statev2/build.gradle.kts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,15 @@ dependencies {
2121
compileOnly(libs.versions.universalcraft.map { "gg.essential:universalcraft-1.8.9-forge:$it" }) {
2222
attributes { attribute(common, true) }
2323
}
24+
25+
testImplementation(kotlin("test"))
26+
testImplementation(project(":"))
27+
}
28+
29+
tasks.test {
30+
useJUnitPlatform()
2431
}
32+
2533
tasks.compileKotlin.setJvmDefault("all")
2634

2735
kotlin.jvmToolchain {
Lines changed: 382 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,382 @@
1+
package gg.essential.elementa.state.v2
2+
3+
import java.time.Duration
4+
import java.time.Instant
5+
6+
class StateScheduler(val time: State<Instant>) {
7+
/**
8+
* Subscribes the given observer to be re-evaluated once the given [targetTime] is reached.
9+
*
10+
* If the given time has already been reached, the observer will not be re-evaluated and `true` will be returned.
11+
*/
12+
fun Observer.observe(targetTime: Instant): Boolean {
13+
// Fast-path for when called after target time has already been reached
14+
if (!targetTime.isAfter(time.getUntracked())) {
15+
return true
16+
}
17+
// Wrapped with `memo` because observeRoughly may wake up multiple times,
18+
// but we want the caller to only wake up once we're actually at the target time.
19+
return memo { observeRoughly(targetTime) }.invoke()
20+
}
21+
22+
private val triggerCache: Array<Pair<Instant, State<Boolean>>?> = arrayOfNulls(Long.SIZE_BITS)
23+
24+
/**
25+
* Subscribes the given observer to be re-evaluated, potentially multiple times, until [targetTime] is
26+
* reached, at which point `true` is returned.
27+
* If the target time has already passed, `true` is returned immediately.
28+
*/
29+
private fun Observer.observeRoughly(targetTime: Instant): Boolean {
30+
val now = time.getUntracked()
31+
val delay = Duration.between(now, targetTime)
32+
if (delay.isNegative || delay.isZero) {
33+
// target time has already passed, we're good to go
34+
return true
35+
} else if (delay <= Duration.ofMillis(8)) {
36+
// target time is very close, just subscribe to the main State directly, we're likely going to hit it by the
37+
// next update anyway
38+
time()
39+
return false
40+
}
41+
42+
// To avoid subscribing (and therefore having to re-compute on each update) thousands of `State` directly to the
43+
// main State, we'll have up to 64 different intermediate states at exponentially far distances in the future,
44+
// and we'll subscribe to one of those instead of the main state. That way the main state should at no point
45+
// have more than 64 state (plus however many short-lived ones were subscribed directly above) to update on each
46+
// tick.
47+
// Once the time of an intermediate state has come, it'll trigger all states subscribed to it, which will all
48+
// then re-sort themselves into closer intermediate states, which should overall result in O(log(n))
49+
// amortized runtime cost each tick instead of the O(n) cost a naive implementation would have.
50+
51+
// For managing the intermediate states, this implementation is looking at how many shared leading bits there
52+
// are between the current timestamp and the target timestamp (the more shared bits, the closer they are), and
53+
// then picking the intermediate state based on that value from an array of length 64.
54+
val nowMs = now.toEpochMilli()
55+
val targetMs = targetTime.toEpochMilli()
56+
val bitsMatching = nowMs.xor(targetMs).countLeadingZeroBits()
57+
val nextWakeupTime = Instant.ofEpochMilli(targetMs.and((-1L).ushr(bitsMatching + 1).inv()))
58+
59+
var trigger = triggerCache[bitsMatching]
60+
if (trigger == null || trigger.first != nextWakeupTime) {
61+
trigger = Pair(nextWakeupTime, memo {
62+
// FIXME States are currently not un-registered when they no longer have any subscribers, only once
63+
// garbage collections deletes them, this is usually good enough but here (and especially during tests)
64+
// it can result in a lot more than 64 states being registered to the root time source state. To avoid
65+
// that, we skip subscribing to the root time source state altogether when our target time has been
66+
// reached, thereby explicitly removing the subscription of this memo from the root time source state.
67+
// Should the State implementation ever be optimized to handle this itself, this can be simplified to:
68+
// timeSource() >= nextWakeupTime
69+
if (time.getUntracked() >= nextWakeupTime) {
70+
true
71+
} else {
72+
time()
73+
false
74+
}
75+
}.let {
76+
// FIXME while a single `memo` is functionality sufficient to decouple evaluation of the subscribers
77+
// from the time source State, a second `memo` is required due to current implementation details of
78+
// State, otherwise the runtime of updating the current time will still be O(n).
79+
memo { it() }
80+
})
81+
triggerCache[bitsMatching] = trigger
82+
}
83+
trigger.second.invoke()
84+
return false
85+
}
86+
87+
companion object {
88+
private val systemTime = mutableStateOf(Instant.EPOCH)
89+
90+
/**
91+
* Scheduler for the current system time as reported by [Instant.now].
92+
*
93+
* Needs to be regularly updated on the main thread via [updateSystemTime].
94+
*/
95+
@JvmStatic
96+
val forSystemTime: StateScheduler = StateScheduler(systemTime)
97+
98+
@JvmStatic
99+
fun updateSystemTime(now: Instant = Instant.now()) {
100+
systemTime.set(now)
101+
}
102+
}
103+
}
104+
105+
/**
106+
* Returns the time of the given [scheduler] as an [ObservedInstant] which will track operations applied to it and
107+
* subscribe the [Observer] to be re-evaluated when the result of any of these operations changes.
108+
*
109+
* Note that the [ObservedInstant] wraps the [Observer] and as such the same life-time restrictions apply to it.
110+
* In particular that means that the [ObservedInstant] or any [ObservedValue] derived from it MUST NOT become the
111+
* value of the [State], only concrete types, e.g. as returned by [ObservedValue.getValue], may.
112+
*
113+
* When performing more complex operations on the returned value, using [withSystemTime] may be more efficient.
114+
*/
115+
fun Observer.systemTime(scheduler: StateScheduler = StateScheduler.forSystemTime): ObservedInstant =
116+
withSystemTime(scheduler) { it }
117+
118+
/**
119+
* Runs the given [block] with the time of the given [scheduler] as an [ObservedInstant] which will track operations
120+
* applied to it and subscribe the [Observer] to be re-evaluated when the result of any of these operations changes.
121+
*
122+
* Note that the [ObservedInstant] wraps the [Observer] and as such the same life-time restrictions apply to it.
123+
* In particular that means that the [ObservedInstant] or any [ObservedValue] derived from it MUST NOT become the
124+
* value of the [State], only concrete types, e.g. as returned by [ObservedValue.getValue], may.
125+
*/
126+
fun <T> Observer.withSystemTime(
127+
scheduler: StateScheduler = StateScheduler.forSystemTime,
128+
block: (ObservedInstant) -> T,
129+
): T {
130+
val now = scheduler.time.getUntracked()
131+
132+
var delayedRegistration = true
133+
var nextWakeupTime = Instant.MAX
134+
val observableInstant = ObservedInstant(now) { wakeupTime ->
135+
if (wakeupTime <= now) return@ObservedInstant
136+
137+
if (wakeupTime > nextWakeupTime) return@ObservedInstant
138+
nextWakeupTime = wakeupTime
139+
140+
if (delayedRegistration) return@ObservedInstant
141+
with(scheduler) { observe(nextWakeupTime) }
142+
}
143+
144+
val result = block(observableInstant)
145+
146+
delayedRegistration = false
147+
if (nextWakeupTime != Instant.MAX) {
148+
with(scheduler) { observe(nextWakeupTime) }
149+
}
150+
151+
return result
152+
}
153+
154+
/** @see withSystemTime */
155+
fun <T> stateUsingSystemTime(block: Observer.(ObservedInstant) -> T) = State { withSystemTime { block(it) } }
156+
157+
/**
158+
* Wraps a value (e.g. a [Long]) and tracks all operations applied.
159+
* This allows the original owner of the value to know how it is used and, crucially, whether a different value would
160+
* give different results.
161+
*
162+
* Implementations will provide various utility methods to operate on the contained value in a tracked way.
163+
* Any time such an operation is performed on the value, [changesAt] is called with the nearest value(s) that
164+
* would give a different result.
165+
* E.g. If the value of a [ObservedLong] is 5 and the user calls `lessOrEqual(7)`, that call will return `true`
166+
* and `changesAt(8)` is called because 8 would be the closest value to return `false`.
167+
* If `lessOrEqual(3)` is called, `false` is returned and `changesAt(3)` is called because 3 would be the
168+
* closest value to return `true`.
169+
* If `toString()` is called, `"5"` is returned and `changesAt` is called once with 4 and once with 6 because
170+
* any change in either direction will give a different result.
171+
* Note that while for these simple examples, `changesAt` is only called exactly as often as necessary and with
172+
* the exact value at which the change happens, this is not a strict requirement; it may be called multiple times
173+
* and/or with values closer than the next change; the only requirement is that at least one call must be at or
174+
* closer than the point at which the change happens, such that if we always re-evaluate a computation at the
175+
* closest point, we won't miss the change.
176+
*
177+
* To get the underlying value, one may call [getValue]. Since the returned value can not be tracked any further from
178+
* that point on, this will cause the computation to be re-evaluated if the returned value changes in any way.
179+
* To get the value without subscribing to all changes, use [untracked] and manually call [changesAt] as appropriate.
180+
*/
181+
interface ObservedValue<T> {
182+
val untracked: T
183+
val changesAt: (T) -> Unit
184+
185+
fun getValue(): T
186+
}
187+
188+
/** An [ObservedValue] for [Instant]. */
189+
class ObservedInstant(override val untracked: Instant, override val changesAt: (Instant) -> Unit) : ObservedValue<Instant> {
190+
override fun getValue(): Instant {
191+
if (untracked != Instant.MIN) changesAt(untracked.minusNanos(1))
192+
if (untracked != Instant.MAX) changesAt(untracked.plusNanos(1))
193+
return untracked
194+
}
195+
196+
override fun toString(): String = getValue().toString()
197+
198+
fun toEpochMillis() = ObservedLong(untracked.toEpochMilli()) { changesAt(Instant.ofEpochMilli(it)) }
199+
200+
fun isBefore(other: Instant) = ObservedDuration.between(this, other).isPositive
201+
fun isAfter(other: Instant) = ObservedDuration.between(other, this).isPositive
202+
203+
fun since(startInclusive: Instant): ObservedDuration = ObservedDuration.between(startInclusive, this)
204+
fun until(endExclusive: Instant): ObservedDuration = ObservedDuration.between(this, endExclusive)
205+
}
206+
207+
/** An [ObservedValue] for [Duration]. */
208+
class ObservedDuration(override val untracked: Duration, override val changesAt: (Duration) -> Unit) : ObservedValue<Duration> {
209+
override fun getValue(): Duration {
210+
if (untracked != MIN_DURATION) changesAt(untracked.minusNanos(1))
211+
if (untracked != MAX_DURATION) changesAt(untracked.plusNanos(1))
212+
return untracked
213+
}
214+
215+
override fun toString(): String = getValue().toString()
216+
217+
fun toMillis(): ObservedLong =
218+
ObservedLong(untracked.toMillis()) { changesAt(Duration.ofMillis(it)) }
219+
220+
val isNegative
221+
get() = untracked.isNegative.also { changesAt(if (it) Duration.ZERO else SMALLEST_NEGATIVE_DURATION) }
222+
223+
val isPositive
224+
get() = (!untracked.isNegative && !untracked.isZero).also { changesAt(if (it) Duration.ZERO else SMALLEST_POSITIVE_DURATION) }
225+
226+
val isZero
227+
get() = !isNegative && !isPositive
228+
229+
companion object {
230+
private val MIN_DURATION = Duration.ofSeconds(Long.MIN_VALUE, 0)
231+
private val MAX_DURATION = Duration.ofSeconds(Long.MAX_VALUE, 999999999)
232+
private val SMALLEST_POSITIVE_DURATION = Duration.ofNanos(1)
233+
private val SMALLEST_NEGATIVE_DURATION = Duration.ofNanos(-1)
234+
235+
fun between(startInclusive: Instant, endExclusive: Instant) =
236+
ObservedDuration(Duration.between(startInclusive, endExclusive)) {}
237+
238+
fun between(startInclusive: ObservedInstant, endExclusive: Instant) =
239+
ObservedDuration(Duration.between(startInclusive.untracked, endExclusive)) { duration ->
240+
startInclusive.changesAt(endExclusive.minus(duration))
241+
}
242+
243+
fun between(startInclusive: Instant, endExclusive: ObservedInstant) =
244+
ObservedDuration(Duration.between(startInclusive, endExclusive.untracked)) { duration ->
245+
endExclusive.changesAt(startInclusive.plus(duration))
246+
}
247+
248+
fun between(startInclusive: ObservedInstant, endExclusive: ObservedInstant) =
249+
ObservedDuration(Duration.between(startInclusive.untracked, endExclusive.untracked)) { duration ->
250+
startInclusive.changesAt(endExclusive.untracked.minus(duration))
251+
endExclusive.changesAt(startInclusive.untracked.plus(duration))
252+
}
253+
}
254+
}
255+
256+
/** An [ObservedValue] for [Long]. */
257+
class ObservedLong(override val untracked: Long, override val changesAt: (Long) -> Unit) : ObservedValue<Long> {
258+
override fun getValue(): Long {
259+
changesAt(untracked - 1)
260+
changesAt(untracked + 1)
261+
return untracked
262+
}
263+
264+
override fun toString(): String = getValue().toString()
265+
266+
// Note: Intentionally not implementing `compareTo` because that has two cutoff points and so we'll likely
267+
// unnecessarily re-evaluate just 1 before we actually need to re-evaluate.
268+
fun equal(other: Long) =
269+
(untracked == other).also { if (it) getValue() else changesAt(other) }
270+
fun less(other: Long) =
271+
(untracked < other).also { changesAt(if (it) other else other - 1) }
272+
fun lessOrEqual(other: Long) =
273+
(untracked <= other).also { changesAt(if (it) other + 1 else other) }
274+
fun greater(other: Long) =
275+
!lessOrEqual(other)
276+
fun greaterOrEqual(other: Long) =
277+
!less(other)
278+
279+
operator fun unaryMinus(): ObservedLong =
280+
ObservedLong(-untracked) { changesAt(-it) }
281+
282+
operator fun plus(other: Long): ObservedLong =
283+
ObservedLong(untracked + other) { changesAt(it - other) }
284+
operator fun minus(other: Long): ObservedLong = plus(-other)
285+
286+
operator fun times(other: Long): ObservedLong {
287+
if (other == 0L) return ObservedLong(0) {}
288+
if (other < 0) return -times(-other)
289+
val oldResult = untracked * other
290+
return ObservedLong(oldResult) { newResult ->
291+
if (newResult > oldResult) changesAt((newResult + (other - 1)).floorDiv(other))
292+
if (newResult < oldResult) changesAt(newResult.floorDiv(other))
293+
}
294+
}
295+
operator fun div(other: Long): ObservedLong {
296+
if (other < 0) return -(this / -other)
297+
if (untracked < 0) return -(-this / other)
298+
val oldResult = untracked / other
299+
return ObservedLong(oldResult) { newResult ->
300+
if (newResult > oldResult) changesAt(newResult * other)
301+
if (newResult < oldResult) changesAt((newResult + 1) * other - 1)
302+
}
303+
}
304+
305+
private fun both(a: ObservedLong, b: ObservedLong): ObservedLong {
306+
assert(a.untracked == b.untracked)
307+
return ObservedLong(a.untracked) { a.changesAt(it); b.changesAt(it) }
308+
}
309+
310+
operator fun plus(other: ObservedLong): ObservedLong =
311+
both(this.plus(other.untracked), other.plus(this.untracked))
312+
operator fun minus(other: ObservedLong): ObservedLong =
313+
plus(-other)
314+
operator fun times(other: ObservedLong): ObservedLong =
315+
both(this.times(other.untracked), other.times(this.untracked))
316+
}
317+
318+
/**
319+
* Explores all possible return values of the given [func] when called with values from the range given by [bounds]
320+
* provided [observed] constructs an observed type like [ObservedLong].
321+
*/
322+
@ForTestingOnly
323+
fun <T : Comparable<T>, OT : ObservedValue<T>, R> explore(
324+
bounds: ClosedRange<T>,
325+
observed: (T, (T) -> Unit) -> OT,
326+
func: (OT) -> R,
327+
): List<R> {
328+
fun eval(arg: T): Triple<R, T, T> {
329+
var lowerNext = bounds.start
330+
var upperNext = bounds.endInclusive
331+
val observedArg = observed(arg) { next ->
332+
if (next < arg) {
333+
if (next > lowerNext) {
334+
lowerNext = next
335+
}
336+
} else if (next > arg) {
337+
if (next < upperNext) {
338+
upperNext = next
339+
}
340+
}
341+
}
342+
val value = try {
343+
func(observedArg)
344+
} catch (e: Exception) {
345+
throw AssertionError("Failed to evaluate func with $arg", e)
346+
}
347+
return Triple(value, lowerNext, upperNext)
348+
}
349+
350+
var prev = bounds.start
351+
var prevValue = func(observed(prev) {})
352+
val results = mutableListOf(prevValue)
353+
var curr = bounds.start
354+
while (true) {
355+
val (value, lowerNext, upperNext) = eval(curr)
356+
if (value != prevValue) {
357+
assert(prev <= lowerNext) {
358+
"When ran with $curr, func returned $value and " +
359+
"reported its closest lower change point to be $lowerNext, " +
360+
"however when run with $prev it produces $prevValue, which contradicts this."
361+
}
362+
results.add(value)
363+
val (lowerValue, _, _) = eval(lowerNext)
364+
assert(lowerValue == prevValue) {
365+
"When ran with $prev, func returned $prevValue and " +
366+
"reported its closest upper change point to be $curr, " +
367+
"however when run with $lowerNext it produces $lowerValue, which contradicts this."
368+
}
369+
}
370+
if (curr == bounds.endInclusive) {
371+
break
372+
}
373+
prev = curr
374+
prevValue = value
375+
curr = upperNext
376+
}
377+
return results
378+
}
379+
380+
/** Annotated members are meant for use in unit tests and provide no API/ABI stability guarantees. */
381+
@RequiresOptIn(level = RequiresOptIn.Level.ERROR)
382+
annotation class ForTestingOnly

0 commit comments

Comments
 (0)