Skip to content

Commit 1c093e9

Browse files
committed
kvcache: Remove special case for reservation mask
We currently short-circuit generation of the cache mask during reservation passes and just generate an empty tensor of the correct size. However, in some cases this also skips a cast operation, so the graph built for the worst case is not actually the full worst case. We don't need the fast path for mask generation, so it's better to just use the normal code path.
1 parent a8d9c26 commit 1c093e9

File tree

1 file changed

+1
-11
lines changed

1 file changed

+1
-11
lines changed

kvcache/causal.go

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,6 @@ type Causal struct {
4040

4141
// ** current forward pass **
4242

43-
// curReserve indicates that this forward pass is only for
44-
// memory reservation and we should not update our metadata
45-
// based on it.
46-
curReserve bool
47-
4843
// the active layer for Get and Put
4944
curLayer int
5045

@@ -206,13 +201,12 @@ func (c *Causal) Close() {
206201
}
207202

208203
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
209-
c.curReserve = reserve
210204
c.curBatchSize = len(batch.Positions)
211205
c.curSequences = batch.Sequences
212206
c.curPositions = batch.Positions
213207
c.opts.Except = nil
214208

215-
if !c.curReserve {
209+
if !reserve {
216210
c.updateSlidingWindow()
217211

218212
var err error
@@ -379,10 +373,6 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
379373

380374
length := c.curCellRange.max - c.curCellRange.min + 1
381375

382-
if c.curReserve {
383-
return ctx.Input().Empty(c.config.MaskDType, length, batchSize)
384-
}
385-
386376
mask := make([]float32, batchSize*length)
387377

388378
for i := range c.curBatchSize {

0 commit comments

Comments
 (0)