Skip to content

Commit ec5489c

Browse files
authored
fix: simplify type assertions (#159)
Signed-off-by: Chris Gianelloni <[email protected]>
1 parent 3007f34 commit ec5489c

File tree

1 file changed

+93
-83
lines changed

1 file changed

+93
-83
lines changed

cek/machine.go

Lines changed: 93 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,6 @@ var (
3838
// getCompute returns a Compute state from the pool
3939
func getCompute[T syn.Eval]() *Compute[T] {
4040
c := computePool.Get().(*Compute[syn.DeBruijn])
41-
if debug {
42-
// Runtime type assertion for safety in debug builds
43-
_ = (*Compute[T])(unsafe.Pointer(c))
44-
}
4541
return (*Compute[T])(unsafe.Pointer(c))
4642
}
4743

@@ -57,10 +53,6 @@ func putCompute[T syn.Eval](c *Compute[T]) {
5753
// getReturn returns a Return state from the pool
5854
func getReturn[T syn.Eval]() *Return[T] {
5955
r := returnPool.Get().(*Return[syn.DeBruijn])
60-
if debug {
61-
// Runtime type assertion for safety in debug builds
62-
_ = (*Return[T])(unsafe.Pointer(r))
63-
}
6456
return (*Return[T])(unsafe.Pointer(r))
6557
}
6658

@@ -69,24 +61,20 @@ func putReturn[T syn.Eval](r *Return[T]) {
6961
// Reset the state
7062
r.Ctx = nil
7163
r.Value = nil
72-
returnPool.Put(r)
64+
returnPool.Put((*Return[syn.DeBruijn])(unsafe.Pointer(r)))
7365
}
7466

7567
// getDone returns a Done state from the pool
7668
func getDone[T syn.Eval]() *Done[T] {
7769
d := donePool.Get().(*Done[syn.DeBruijn])
78-
if debug {
79-
// Runtime type assertion for safety in debug builds
80-
_ = (*Done[T])(unsafe.Pointer(d))
81-
}
8270
return (*Done[T])(unsafe.Pointer(d))
8371
}
8472

8573
// putDone returns a Done state to the pool
8674
func putDone[T syn.Eval](d *Done[T]) {
8775
// Reset the state
8876
d.term = nil
89-
donePool.Put(d)
77+
donePool.Put((*Done[syn.DeBruijn])(unsafe.Pointer(d)))
9078
}
9179

9280
type Machine[T syn.Eval] struct {
@@ -140,10 +128,13 @@ func (m *Machine[T]) Run(term syn.Term[T]) (syn.Term[T], error) {
140128
return nil, err
141129
}
142130

143-
var state MachineState[T] = getCompute[T]()
144-
state.(*Compute[T]).Ctx = &NoFrame{}
145-
state.(*Compute[T]).Env = nil
146-
state.(*Compute[T]).Term = term
131+
var state MachineState[T]
132+
133+
comp := getCompute[T]()
134+
comp.Ctx = &NoFrame{}
135+
comp.Env = nil
136+
comp.Term = term
137+
state = comp
147138

148139
for {
149140
switch v := state.(type) {
@@ -200,9 +191,10 @@ func (m *Machine[T]) compute(
200191
return nil, errors.New("open term evaluated")
201192
}
202193

203-
state = getReturn[T]()
204-
state.(*Return[T]).Ctx = context
205-
state.(*Return[T]).Value = value
194+
ret := getReturn[T]()
195+
ret.Ctx = context
196+
ret.Value = value
197+
state = ret
206198
case *syn.Delay[T]:
207199
if err := m.stepAndMaybeSpend(ExDelay); err != nil {
208200
return nil, err
@@ -213,9 +205,10 @@ func (m *Machine[T]) compute(
213205
Env: env,
214206
}
215207

216-
state = getReturn[T]()
217-
state.(*Return[T]).Ctx = context
218-
state.(*Return[T]).Value = value
208+
ret := getReturn[T]()
209+
ret.Ctx = context
210+
ret.Value = value
211+
state = ret
219212
case *syn.Lambda[T]:
220213
if err := m.stepAndMaybeSpend(ExLambda); err != nil {
221214
return nil, err
@@ -227,9 +220,10 @@ func (m *Machine[T]) compute(
227220
Env: env,
228221
}
229222

230-
state = getReturn[T]()
231-
state.(*Return[T]).Ctx = context
232-
state.(*Return[T]).Value = value
223+
ret := getReturn[T]()
224+
ret.Ctx = context
225+
ret.Value = value
226+
state = ret
233227
case *syn.Apply[T]:
234228
if err := m.stepAndMaybeSpend(ExApply); err != nil {
235229
return nil, err
@@ -241,20 +235,22 @@ func (m *Machine[T]) compute(
241235
Ctx: context,
242236
}
243237

244-
state = getCompute[T]()
245-
state.(*Compute[T]).Ctx = frame
246-
state.(*Compute[T]).Env = env
247-
state.(*Compute[T]).Term = t.Function
238+
comp := getCompute[T]()
239+
comp.Ctx = frame
240+
comp.Env = env
241+
comp.Term = t.Function
242+
state = comp
248243
case *syn.Constant:
249244
if err := m.stepAndMaybeSpend(ExConstant); err != nil {
250245
return nil, err
251246
}
252247

253-
state = getReturn[T]()
254-
state.(*Return[T]).Ctx = context
255-
state.(*Return[T]).Value = &Constant{
248+
ret := getReturn[T]()
249+
ret.Ctx = context
250+
ret.Value = &Constant{
256251
Constant: t.Con,
257252
}
253+
state = ret
258254
case *syn.Force[T]:
259255
if err := m.stepAndMaybeSpend(ExForce); err != nil {
260256
return nil, err
@@ -264,10 +260,11 @@ func (m *Machine[T]) compute(
264260
Ctx: context,
265261
}
266262

267-
state = getCompute[T]()
268-
state.(*Compute[T]).Ctx = frame
269-
state.(*Compute[T]).Env = env
270-
state.(*Compute[T]).Term = t.Term
263+
comp := getCompute[T]()
264+
comp.Ctx = frame
265+
comp.Env = env
266+
comp.Term = t.Term
267+
state = comp
271268
case *syn.Error:
272269
return nil, errors.New("error explicitly called")
273270

@@ -276,13 +273,14 @@ func (m *Machine[T]) compute(
276273
return nil, err
277274
}
278275

279-
state = getReturn[T]()
280-
state.(*Return[T]).Ctx = context
281-
state.(*Return[T]).Value = &Builtin[T]{
276+
ret := getReturn[T]()
277+
ret.Ctx = context
278+
ret.Value = &Builtin[T]{
282279
Func: t.DefaultFunction,
283280
Args: nil,
284281
Forces: 0,
285282
}
283+
state = ret
286284
case *syn.Constr[T]:
287285
if err := m.stepAndMaybeSpend(ExConstr); err != nil {
288286
return nil, err
@@ -291,12 +289,13 @@ func (m *Machine[T]) compute(
291289
fields := t.Fields
292290

293291
if len(fields) == 0 {
294-
state = getReturn[T]()
295-
state.(*Return[T]).Ctx = context
296-
state.(*Return[T]).Value = &Constr[T]{
292+
ret := getReturn[T]()
293+
ret.Ctx = context
294+
ret.Value = &Constr[T]{
297295
Tag: t.Tag,
298296
Fields: []Value[T]{},
299297
}
298+
state = ret
300299
} else {
301300
first_field := fields[0]
302301

@@ -310,10 +309,11 @@ func (m *Machine[T]) compute(
310309
Env: env,
311310
}
312311

313-
state = getCompute[T]()
314-
state.(*Compute[T]).Ctx = frame
315-
state.(*Compute[T]).Env = env
316-
state.(*Compute[T]).Term = first_field
312+
comp := getCompute[T]()
313+
comp.Ctx = frame
314+
comp.Env = env
315+
comp.Term = first_field
316+
state = comp
317317
}
318318
case *syn.Case[T]:
319319
if err := m.stepAndMaybeSpend(ExCase); err != nil {
@@ -326,10 +326,11 @@ func (m *Machine[T]) compute(
326326
Branches: t.Branches,
327327
}
328328

329-
state = getCompute[T]()
330-
state.(*Compute[T]).Ctx = frame
331-
state.(*Compute[T]).Env = env
332-
state.(*Compute[T]).Term = t.Constr
329+
comp := getCompute[T]()
330+
comp.Ctx = frame
331+
comp.Env = env
332+
comp.Term = t.Constr
333+
state = comp
333334
default:
334335
panic(fmt.Sprintf("unknown term: %T: %v", term, term))
335336
}
@@ -355,13 +356,14 @@ func (m *Machine[T]) returnCompute(
355356
return nil, err
356357
}
357358
case *FrameAwaitFunTerm[T]:
358-
state = getCompute[T]()
359-
state.(*Compute[T]).Ctx = &FrameAwaitArg[T]{
359+
comp := getCompute[T]()
360+
comp.Ctx = &FrameAwaitArg[T]{
360361
Ctx: c.Ctx,
361362
Value: value,
362363
}
363-
state.(*Compute[T]).Env = c.Env
364-
state.(*Compute[T]).Term = c.Term
364+
comp.Env = c.Env
365+
comp.Term = c.Term
366+
state = comp
365367
case *FrameAwaitFunValue[T]:
366368
state, err = m.applyEvaluate(c.Ctx, value, c.Value)
367369
if err != nil {
@@ -378,12 +380,13 @@ func (m *Machine[T]) returnCompute(
378380
fields := c.Fields
379381

380382
if len(fields) == 0 {
381-
state = getReturn[T]()
382-
state.(*Return[T]).Ctx = c.Ctx
383-
state.(*Return[T]).Value = &Constr[T]{
383+
ret := getReturn[T]()
384+
ret.Ctx = c.Ctx
385+
ret.Value = &Constr[T]{
384386
Tag: c.Tag,
385387
Fields: resolvedFields,
386388
}
389+
state = ret
387390
} else {
388391
first_field := fields[0]
389392
rest := fields[1:]
@@ -396,10 +399,11 @@ func (m *Machine[T]) returnCompute(
396399
Env: c.Env,
397400
}
398401

399-
state = getCompute[T]()
400-
state.(*Compute[T]).Ctx = frame
401-
state.(*Compute[T]).Env = c.Env
402-
state.(*Compute[T]).Term = first_field
402+
comp := getCompute[T]()
403+
comp.Ctx = frame
404+
comp.Env = c.Env
405+
comp.Term = first_field
406+
state = comp
403407
}
404408
case *FrameCases[T]:
405409
switch v := value.(type) {
@@ -408,10 +412,11 @@ func (m *Machine[T]) returnCompute(
408412
return nil, errors.New("MaxIntExceeded")
409413
}
410414
if indexExists(c.Branches, int(v.Tag)) {
411-
state = getCompute[T]()
412-
state.(*Compute[T]).Ctx = transferArgStack(v.Fields, c.Ctx)
413-
state.(*Compute[T]).Env = c.Env
414-
state.(*Compute[T]).Term = c.Branches[v.Tag]
415+
comp := getCompute[T]()
416+
comp.Ctx = transferArgStack(v.Fields, c.Ctx)
417+
comp.Env = c.Env
418+
comp.Term = c.Branches[v.Tag]
419+
state = comp
415420
} else {
416421
return nil, errors.New("MissingCaseBranch")
417422
}
@@ -425,8 +430,9 @@ func (m *Machine[T]) returnCompute(
425430
}
426431
}
427432

428-
state = getDone[T]()
429-
state.(*Done[T]).term = dischargeValue[T](value)
433+
done := getDone[T]()
434+
done.term = dischargeValue[T](value)
435+
state = done
430436
default:
431437
panic(fmt.Sprintf("unknown context %v", context))
432438
}
@@ -446,10 +452,11 @@ func (m *Machine[T]) forceEvaluate(
446452

447453
switch v := value.(type) {
448454
case *Delay[T]:
449-
state = getCompute[T]()
450-
state.(*Compute[T]).Ctx = context
451-
state.(*Compute[T]).Env = v.Env
452-
state.(*Compute[T]).Term = v.Body
455+
comp := getCompute[T]()
456+
comp.Ctx = context
457+
comp.Env = v.Env
458+
comp.Term = v.Body
459+
state = comp
453460
case *Builtin[T]:
454461
if v.NeedsForce() {
455462
var resolved Value[T]
@@ -467,9 +474,10 @@ func (m *Machine[T]) forceEvaluate(
467474
resolved = b
468475
}
469476

470-
state = getReturn[T]()
471-
state.(*Return[T]).Ctx = context
472-
state.(*Return[T]).Value = resolved
477+
ret := getReturn[T]()
478+
ret.Ctx = context
479+
ret.Value = resolved
480+
state = ret
473481
} else {
474482
return nil, errors.New("BuiltinTermArgumentExpected")
475483
}
@@ -491,10 +499,11 @@ func (m *Machine[T]) applyEvaluate(
491499
case *Lambda[T]:
492500
env := f.Env.Extend(arg)
493501

494-
state = getCompute[T]()
495-
state.(*Compute[T]).Ctx = context
496-
state.(*Compute[T]).Env = env
497-
state.(*Compute[T]).Term = f.Body
502+
comp := getCompute[T]()
503+
comp.Ctx = context
504+
comp.Env = env
505+
comp.Term = f.Body
506+
state = comp
498507
case *Builtin[T]:
499508
if !f.NeedsForce() && f.IsArrow() {
500509
var resolved Value[T]
@@ -512,9 +521,10 @@ func (m *Machine[T]) applyEvaluate(
512521
resolved = b
513522
}
514523

515-
state = getReturn[T]()
516-
state.(*Return[T]).Ctx = context
517-
state.(*Return[T]).Value = resolved
524+
ret := getReturn[T]()
525+
ret.Ctx = context
526+
ret.Value = resolved
527+
state = ret
518528
} else {
519529
return nil, errors.New("UnexpectedBuiltinTermArgument")
520530
}

0 commit comments

Comments
 (0)