Skip to content

Commit a50ec57

Browse files
committed
add WithAddedFilter to provide multiple filters
1 parent 78ed49b commit a50ec57

File tree

2 files changed

+99
-0
lines changed

2 files changed

+99
-0
lines changed

options.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,33 @@ func WithFilter(filter func(Model, Msg) Msg) ProgramOption {
229229
}
230230
}
231231

232+
// WithAddedFilter supplies an event filter that will be invoked before Bubble
233+
// Tea processes a tea.Msg. Multiple filters may be chained using this option.
234+
// Filters are invoked in order. Any filter that returns nil stops the remaining
235+
// filters form being invoked, and the event will be ignored.
236+
//
237+
// See WithFilter for more information about filtering.
238+
func WithAddedFilter(filter func(Model, Msg) Msg) ProgramOption {
239+
return func(p *Program) {
240+
prev := p.filter
241+
if prev == nil {
242+
WithFilter(filter)(p)
243+
return
244+
}
245+
246+
WithFilter(func(m Model, msg Msg) Msg {
247+
// Invoke the previous filter in the chain.
248+
msg = prev(m, msg)
249+
if msg == nil {
250+
// Previous filter ignored the event.
251+
return msg
252+
}
253+
254+
return filter(m, msg)
255+
})(p)
256+
}
257+
}
258+
232259
// WithFPS sets a custom maximum FPS at which the renderer should run. If
233260
// less than 1, the default value of 60 will be used. If over 120, the FPS
234261
// will be capped at 120.

options_test.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,4 +140,76 @@ func TestOptions(t *testing.T) {
140140
}
141141
}
142142
})
143+
144+
t.Run("multiple filters", func(t *testing.T) {
145+
type multiIncrement struct {
146+
num int
147+
}
148+
type eventuallyIncrementMsg incrementMsg
149+
150+
// This filter converts multiIncrement to a sequence of eventuallyIncrementMsg.
151+
a := func(m Model, msg Msg) Msg {
152+
if mul, ok := msg.(multiIncrement); ok {
153+
var cmds []Cmd
154+
for range mul.num {
155+
cmds = append(cmds, func() Msg {
156+
return eventuallyIncrementMsg{}
157+
})
158+
}
159+
return sequenceMsg(cmds)
160+
}
161+
return msg
162+
}
163+
164+
// This filter converts eventuallyIncrementMsg into incrementMsg.
165+
// If loaded out of order, the c filter breaks.
166+
b := func(_ Model, msg Msg) Msg {
167+
if msg, ok := msg.(eventuallyIncrementMsg); ok {
168+
return incrementMsg(msg)
169+
}
170+
return msg
171+
}
172+
173+
// This filter quits after 10 incrementMsg.
174+
// Requires the b filter to work.
175+
c := func(m Model, msg Msg) Msg {
176+
p := m.(*testModel)
177+
// Stop after 10 increments.
178+
if _, ok := msg.(incrementMsg); ok {
179+
if v := p.counter.Load(); v != nil && v.(int) >= 10 {
180+
return QuitMsg{}
181+
}
182+
}
183+
184+
return msg
185+
}
186+
187+
var (
188+
buf bytes.Buffer
189+
in bytes.Buffer
190+
m = &testModel{}
191+
)
192+
p := NewProgram(m,
193+
// The combination of filters a, b, and c in this test causes the test
194+
// to correctly quit at 10 increments.
195+
196+
// Convert into multiple eventuallyIncrementMsg.
197+
WithAddedFilter(a),
198+
// Convert into incrementMsg.
199+
WithAddedFilter(b),
200+
// Quit when the number of messages reaches 10.
201+
WithAddedFilter(c),
202+
203+
WithInput(&in),
204+
WithOutput(&buf))
205+
go p.Send(multiIncrement{num: 20})
206+
207+
if _, err := p.Run(); err != nil {
208+
t.Fatal(err)
209+
}
210+
211+
if m.counter.Load().(int) != 10 {
212+
t.Fatalf("counter should be 10, got %d", m.counter.Load())
213+
}
214+
})
143215
}

0 commit comments

Comments
 (0)