Skip to content

Commit 56028e9

Browse files
authored
Expose hooks for SIP trunk and rule matching. (#983)
1 parent 0a12e2c commit 56028e9

File tree

6 files changed

+260
-27
lines changed

6 files changed

+260
-27
lines changed

.changeset/sip-match-report.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"github.com/livekit/protocol": minor
3+
---
4+
5+
Expose hooks for SIP trunk and rule matching.

sip/dispatchruleconflictreason_string.go

Lines changed: 23 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sip/sip.go

Lines changed: 171 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ import (
3636
"github.com/livekit/protocol/utils/guid"
3737
)
3838

39+
//go:generate stringer -type TrunkFilteredReason -trimprefix TrunkFiltered
40+
//go:generate stringer -type TrunkConflictReason -trimprefix TrunkConflict
41+
//go:generate stringer -type DispatchRuleConflictReason -trimprefix DispatchRuleConflict
42+
3943
func NewCallID() string {
4044
return guid.New(utils.SIPCallPrefix)
4145
}
@@ -130,14 +134,14 @@ func printID(s string) string {
130134
// ValidateDispatchRules checks a set of dispatch rules for conflicts.
131135
//
132136
// Deprecated: use ValidateDispatchRulesIter
133-
func ValidateDispatchRules(rules []*livekit.SIPDispatchRuleInfo) error {
134-
_, err := ValidateDispatchRulesIter(iters.Slice(rules))
137+
func ValidateDispatchRules(rules []*livekit.SIPDispatchRuleInfo, opts ...MatchDispatchRuleOpt) error {
138+
_, err := ValidateDispatchRulesIter(iters.Slice(rules), opts...)
135139
return err
136140
}
137141

138142
// ValidateDispatchRulesIter checks a set of dispatch rules for conflicts.
139-
func ValidateDispatchRulesIter(it iters.Iter[*livekit.SIPDispatchRuleInfo]) (best *livekit.SIPDispatchRuleInfo, _ error) {
140-
it = NewDispatchRuleValidator().ValidateIter(it)
143+
func ValidateDispatchRulesIter(it iters.Iter[*livekit.SIPDispatchRuleInfo], opts ...MatchDispatchRuleOpt) (best *livekit.SIPDispatchRuleInfo, _ error) {
144+
it = NewDispatchRuleValidator(opts...).ValidateIter(it)
141145
defer it.Close()
142146
for {
143147
r, err := it.Next()
@@ -153,8 +157,14 @@ func ValidateDispatchRulesIter(it iters.Iter[*livekit.SIPDispatchRuleInfo]) (bes
153157
return best, nil
154158
}
155159

156-
func NewDispatchRuleValidator() *DispatchRuleValidator {
160+
func NewDispatchRuleValidator(opts ...MatchDispatchRuleOpt) *DispatchRuleValidator {
161+
var opt matchDispatchRuleOpts
162+
for _, fnc := range opts {
163+
fnc(&opt)
164+
}
165+
opt.defaults()
157166
return &DispatchRuleValidator{
167+
opt: opt,
158168
byRuleKey: make(map[dispatchRuleKey]*livekit.SIPDispatchRuleInfo),
159169
}
160170
}
@@ -166,6 +176,7 @@ type dispatchRuleKey struct {
166176
}
167177

168178
type DispatchRuleValidator struct {
179+
opt matchDispatchRuleOpts
169180
byRuleKey map[dispatchRuleKey]*livekit.SIPDispatchRuleInfo
170181
}
171182

@@ -193,6 +204,10 @@ func (v *DispatchRuleValidator) Validate(r *livekit.SIPDispatchRuleInfo) error {
193204
key := dispatchRuleKey{Pin: pin, Trunk: trunk, Number: normalizeNumber(number)}
194205
r2 := v.byRuleKey[key]
195206
if r2 != nil {
207+
v.opt.Conflict(r, r2, DispatchRuleConflictGeneric)
208+
if v.opt.AllowConflicts {
209+
continue
210+
}
196211
return twirp.NewErrorf(twirp.InvalidArgument, "Conflicting SIP Dispatch Rules: same Trunk+Number+PIN combination for for %q and %q",
197212
printID(r.SipDispatchRuleId), printID(r2.SipDispatchRuleId))
198213
}
@@ -226,10 +241,10 @@ func (v *dispatchRuleValidatorIter) Close() {
226241
// It returns an error if there are conflicting rules. Returns nil if no rules match.
227242
//
228243
// Deprecated: use MatchDispatchRuleIter
229-
func SelectDispatchRule(rules []*livekit.SIPDispatchRuleInfo, req *rpc.EvaluateSIPDispatchRulesRequest) (*livekit.SIPDispatchRuleInfo, error) {
244+
func SelectDispatchRule(rules []*livekit.SIPDispatchRuleInfo, req *rpc.EvaluateSIPDispatchRulesRequest, opts ...MatchDispatchRuleOpt) (*livekit.SIPDispatchRuleInfo, error) {
230245
// Sorting will do the selection for us. We already filtered out irrelevant ones in MatchDispatchRule and above.
231246
// Nil is fine here. We will report "no rules matched" later.
232-
return ValidateDispatchRulesIter(iters.Slice(rules))
247+
return ValidateDispatchRulesIter(iters.Slice(rules), opts...)
233248
}
234249

235250
// GetPinAndRoom returns a room name/prefix and the pin for a dispatch rule. Just a convenience wrapper.
@@ -270,9 +285,13 @@ func normalizeNumber(num string) string {
270285
return num
271286
}
272287

273-
func validateTrunkInbound(byInbound map[string]*livekit.SIPInboundTrunkInfo, t *livekit.SIPInboundTrunkInfo) error {
288+
func validateTrunkInbound(byInbound map[string]*livekit.SIPInboundTrunkInfo, t *livekit.SIPInboundTrunkInfo, opt *matchTrunkOpts) error {
274289
if len(t.AllowedNumbers) == 0 {
275290
if t2 := byInbound[""]; t2 != nil {
291+
opt.Conflict(t, t2, TrunkConflictCalledNumber)
292+
if opt.AllowConflicts {
293+
return nil
294+
}
276295
return twirp.NewErrorf(twirp.InvalidArgument, "Conflicting inbound SIP Trunks: %q and %q, using the same number(s) %s without AllowedNumbers set",
277296
printID(t.SipTrunkId), printID(t2.SipTrunkId), printNumbers(t.Numbers))
278297
}
@@ -282,6 +301,10 @@ func validateTrunkInbound(byInbound map[string]*livekit.SIPInboundTrunkInfo, t *
282301
inboundKey := normalizeNumber(num)
283302
t2 := byInbound[inboundKey]
284303
if t2 != nil {
304+
opt.Conflict(t, t2, TrunkConflictCallingNumber)
305+
if opt.AllowConflicts {
306+
continue
307+
}
285308
return twirp.NewErrorf(twirp.InvalidArgument, "Conflicting inbound SIP Trunks: %q and %q, using the same number(s) %s and AllowedNumber %q",
286309
printID(t.SipTrunkId), printID(t2.SipTrunkId), printNumbers(t.Numbers), num)
287310
}
@@ -294,13 +317,18 @@ func validateTrunkInbound(byInbound map[string]*livekit.SIPInboundTrunkInfo, t *
294317
// ValidateTrunks checks a set of trunks for conflicts.
295318
//
296319
// Deprecated: use ValidateTrunksIter
297-
func ValidateTrunks(trunks []*livekit.SIPInboundTrunkInfo) error {
298-
return ValidateTrunksIter(iters.Slice(trunks))
320+
func ValidateTrunks(trunks []*livekit.SIPInboundTrunkInfo, opts ...MatchTrunkOpt) error {
321+
return ValidateTrunksIter(iters.Slice(trunks), opts...)
299322
}
300323

301324
// ValidateTrunksIter checks a set of trunks for conflicts.
302-
func ValidateTrunksIter(it iters.Iter[*livekit.SIPInboundTrunkInfo]) error {
325+
func ValidateTrunksIter(it iters.Iter[*livekit.SIPInboundTrunkInfo], opts ...MatchTrunkOpt) error {
303326
defer it.Close()
327+
var opt matchTrunkOpts
328+
for _, fnc := range opts {
329+
fnc(&opt)
330+
}
331+
opt.defaults()
304332
byOutboundAndInbound := make(map[string]map[string]*livekit.SIPInboundTrunkInfo)
305333
for {
306334
t, err := it.Next()
@@ -315,7 +343,7 @@ func ValidateTrunksIter(it iters.Iter[*livekit.SIPInboundTrunkInfo]) error {
315343
byInbound = make(map[string]*livekit.SIPInboundTrunkInfo)
316344
byOutboundAndInbound[""] = byInbound
317345
}
318-
if err := validateTrunkInbound(byInbound, t); err != nil {
346+
if err := validateTrunkInbound(byInbound, t, &opt); err != nil {
319347
return err
320348
}
321349
} else {
@@ -325,7 +353,7 @@ func ValidateTrunksIter(it iters.Iter[*livekit.SIPInboundTrunkInfo]) error {
325353
byInbound = make(map[string]*livekit.SIPInboundTrunkInfo)
326354
byOutboundAndInbound[num] = byInbound
327355
}
328-
if err := validateTrunkInbound(byInbound, t); err != nil {
356+
if err := validateTrunkInbound(byInbound, t, &opt); err != nil {
329357
return err
330358
}
331359
}
@@ -410,18 +438,85 @@ func matchNumbers(num string, allowed []string) bool {
410438
// Returns nil if no rules matched or an error if there are conflicting definitions.
411439
//
412440
// Deprecated: use MatchTrunkIter
413-
func MatchTrunk(trunks []*livekit.SIPInboundTrunkInfo, srcIP netip.Addr, calling, called string) (*livekit.SIPInboundTrunkInfo, error) {
414-
return MatchTrunkIter(iters.Slice(trunks), srcIP, calling, called)
441+
func MatchTrunk(trunks []*livekit.SIPInboundTrunkInfo, srcIP netip.Addr, calling, called string, opts ...MatchTrunkOpt) (*livekit.SIPInboundTrunkInfo, error) {
442+
return MatchTrunkIter(iters.Slice(trunks), srcIP, calling, called, opts...)
443+
}
444+
445+
type matchTrunkOpts struct {
446+
AllowConflicts bool
447+
Filtered TrunkFilteredFunc
448+
Conflict TrunkConflictFunc
449+
}
450+
451+
func (opt *matchTrunkOpts) defaults() {
452+
if opt.Filtered == nil {
453+
opt.Filtered = func(_ *livekit.SIPInboundTrunkInfo, _ TrunkFilteredReason) {}
454+
}
455+
if opt.Conflict == nil {
456+
opt.Conflict = func(_, _ *livekit.SIPInboundTrunkInfo, _ TrunkConflictReason) {}
457+
}
458+
}
459+
460+
type MatchTrunkOpt func(opt *matchTrunkOpts)
461+
462+
type TrunkFilteredReason int
463+
464+
const (
465+
TrunkFilteredInvalid = TrunkFilteredReason(iota)
466+
TrunkFilteredCallingNumberDisallowed
467+
TrunkFilteredCalledNumberDisallowed
468+
TrunkFilteredSourceAddressDisallowed
469+
)
470+
471+
type TrunkFilteredFunc func(tr *livekit.SIPInboundTrunkInfo, reason TrunkFilteredReason)
472+
473+
// WithTrunkFiltered sets a callback that is called when selected Trunk(s) doesn't match the call.
474+
func WithTrunkFiltered(fnc TrunkFilteredFunc) MatchTrunkOpt {
475+
return func(opt *matchTrunkOpts) {
476+
opt.Filtered = fnc
477+
}
478+
}
479+
480+
type TrunkConflictReason int
481+
482+
const (
483+
TrunkConflictDefault = TrunkConflictReason(iota)
484+
TrunkConflictCalledNumber
485+
TrunkConflictCallingNumber
486+
)
487+
488+
type TrunkConflictFunc func(t1, t2 *livekit.SIPInboundTrunkInfo, reason TrunkConflictReason)
489+
490+
// WithAllowTrunkConflicts allows conflicting Trunk definitions by picking the first match.
491+
//
492+
// Using this option will prevent TrunkConflictFunc from firing, since the first match will be returned immediately.
493+
func WithAllowTrunkConflicts() MatchTrunkOpt {
494+
return func(opt *matchTrunkOpts) {
495+
opt.AllowConflicts = true
496+
}
497+
}
498+
499+
// WithTrunkConflict sets a callback that is called when two Trunks conflict.
500+
func WithTrunkConflict(fnc TrunkConflictFunc) MatchTrunkOpt {
501+
return func(opt *matchTrunkOpts) {
502+
opt.Conflict = fnc
503+
}
415504
}
416505

417506
// MatchTrunkIter finds a SIP Trunk definition matching the request.
418507
// Returns nil if no rules matched or an error if there are conflicting definitions.
419-
func MatchTrunkIter(it iters.Iter[*livekit.SIPInboundTrunkInfo], srcIP netip.Addr, calling, called string) (*livekit.SIPInboundTrunkInfo, error) {
508+
func MatchTrunkIter(it iters.Iter[*livekit.SIPInboundTrunkInfo], srcIP netip.Addr, calling, called string, opts ...MatchTrunkOpt) (*livekit.SIPInboundTrunkInfo, error) {
420509
defer it.Close()
510+
var opt matchTrunkOpts
511+
for _, fnc := range opts {
512+
fnc(&opt)
513+
}
514+
opt.defaults()
421515
var (
422-
selectedTrunk *livekit.SIPInboundTrunkInfo
423-
defaultTrunk *livekit.SIPInboundTrunkInfo
424-
defaultTrunkCnt int // to error in case there are multiple ones
516+
selectedTrunk *livekit.SIPInboundTrunkInfo
517+
defaultTrunk *livekit.SIPInboundTrunkInfo
518+
defaultTrunkPrev *livekit.SIPInboundTrunkInfo
519+
defaultTrunkCnt int // to error in case there are multiple ones
425520
)
426521
calledNorm := normalizeNumber(called)
427522
for {
@@ -433,24 +528,38 @@ func MatchTrunkIter(it iters.Iter[*livekit.SIPInboundTrunkInfo], srcIP netip.Add
433528
}
434529
// Do not consider it if number doesn't match.
435530
if !matchNumbers(calling, tr.AllowedNumbers) {
531+
opt.Filtered(tr, TrunkFilteredCallingNumberDisallowed)
436532
continue
437533
}
438534
if !matchAddrMasks(srcIP, tr.AllowedAddresses) {
535+
opt.Filtered(tr, TrunkFilteredSourceAddressDisallowed)
439536
continue
440537
}
441538
if len(tr.Numbers) == 0 {
442539
// Default/wildcard trunk.
540+
defaultTrunkPrev = defaultTrunk
443541
defaultTrunk = tr
444542
defaultTrunkCnt++
445543
} else {
446544
for _, num := range tr.Numbers {
447545
if normalizeNumber(num) == calledNorm {
448546
// Trunk specific to the number.
449547
if selectedTrunk != nil {
548+
opt.Conflict(selectedTrunk, tr, TrunkConflictCalledNumber)
549+
if opt.AllowConflicts {
550+
// This path is unreachable, since we pick the first trunk. Kept for completeness.
551+
continue
552+
}
450553
return nil, twirp.NewErrorf(twirp.FailedPrecondition, "Multiple SIP Trunks matched for %q", called)
451554
}
452555
selectedTrunk = tr
556+
if opt.AllowConflicts {
557+
// Pick the first match as soon as it's found. We don't care about conflicts.
558+
return selectedTrunk, nil
559+
}
453560
// Keep searching! We want to know if there are any conflicting Trunk definitions.
561+
} else {
562+
opt.Filtered(tr, TrunkFilteredCalledNumberDisallowed)
454563
}
455564
}
456565
}
@@ -459,7 +568,10 @@ func MatchTrunkIter(it iters.Iter[*livekit.SIPInboundTrunkInfo], srcIP netip.Add
459568
return selectedTrunk, nil
460569
}
461570
if defaultTrunkCnt > 1 {
462-
return nil, twirp.NewErrorf(twirp.FailedPrecondition, "Multiple default SIP Trunks matched for %q", called)
571+
opt.Conflict(defaultTrunk, defaultTrunkPrev, TrunkConflictDefault)
572+
if !opt.AllowConflicts {
573+
return nil, twirp.NewErrorf(twirp.FailedPrecondition, "Multiple default SIP Trunks matched for %q", called)
574+
}
463575
}
464576
// Could still be nil here.
465577
return defaultTrunk, nil
@@ -469,14 +581,49 @@ func MatchTrunkIter(it iters.Iter[*livekit.SIPInboundTrunkInfo], srcIP netip.Add
469581
// Trunk parameter can be nil, in which case only wildcard dispatch rules will be effective (ones without Trunk IDs).
470582
//
471583
// Deprecated: use MatchDispatchRuleIter
472-
func MatchDispatchRule(trunk *livekit.SIPInboundTrunkInfo, rules []*livekit.SIPDispatchRuleInfo, req *rpc.EvaluateSIPDispatchRulesRequest) (*livekit.SIPDispatchRuleInfo, error) {
473-
return MatchDispatchRuleIter(trunk, iters.Slice(rules), req)
584+
func MatchDispatchRule(trunk *livekit.SIPInboundTrunkInfo, rules []*livekit.SIPDispatchRuleInfo, req *rpc.EvaluateSIPDispatchRulesRequest, opts ...MatchDispatchRuleOpt) (*livekit.SIPDispatchRuleInfo, error) {
585+
return MatchDispatchRuleIter(trunk, iters.Slice(rules), req, opts...)
586+
}
587+
588+
type matchDispatchRuleOpts struct {
589+
AllowConflicts bool
590+
Conflict DispatchRuleConflictFunc
591+
}
592+
593+
func (opt *matchDispatchRuleOpts) defaults() {
594+
if opt.Conflict == nil {
595+
opt.Conflict = func(_, _ *livekit.SIPDispatchRuleInfo, _ DispatchRuleConflictReason) {}
596+
}
597+
}
598+
599+
type MatchDispatchRuleOpt func(opt *matchDispatchRuleOpts)
600+
601+
type DispatchRuleConflictReason int
602+
603+
const (
604+
DispatchRuleConflictGeneric = DispatchRuleConflictReason(iota)
605+
)
606+
607+
type DispatchRuleConflictFunc func(r1, r2 *livekit.SIPDispatchRuleInfo, reason DispatchRuleConflictReason)
608+
609+
// WithAllowDispatchRuleConflicts allows conflicting DispatchRule definitions.
610+
func WithAllowDispatchRuleConflicts() MatchDispatchRuleOpt {
611+
return func(opt *matchDispatchRuleOpts) {
612+
opt.AllowConflicts = true
613+
}
614+
}
615+
616+
// WithDispatchRuleConflict sets a callback that is called when two DispatchRules conflict.
617+
func WithDispatchRuleConflict(fnc DispatchRuleConflictFunc) MatchDispatchRuleOpt {
618+
return func(opt *matchDispatchRuleOpts) {
619+
opt.Conflict = fnc
620+
}
474621
}
475622

476623
// MatchDispatchRuleIter finds the best dispatch rule matching the request parameters. Returns an error if no rule matched.
477624
// Trunk parameter can be nil, in which case only wildcard dispatch rules will be effective (ones without Trunk IDs).
478-
func MatchDispatchRuleIter(trunk *livekit.SIPInboundTrunkInfo, rules iters.Iter[*livekit.SIPDispatchRuleInfo], req *rpc.EvaluateSIPDispatchRulesRequest) (*livekit.SIPDispatchRuleInfo, error) {
479-
rules = NewDispatchRuleValidator().ValidateIter(rules)
625+
func MatchDispatchRuleIter(trunk *livekit.SIPInboundTrunkInfo, rules iters.Iter[*livekit.SIPDispatchRuleInfo], req *rpc.EvaluateSIPDispatchRulesRequest, opts ...MatchDispatchRuleOpt) (*livekit.SIPDispatchRuleInfo, error) {
626+
rules = NewDispatchRuleValidator(opts...).ValidateIter(rules)
480627
defer rules.Close()
481628
// Trunk can still be nil here in case none matched or were defined.
482629
// This is still fine, but only in case we'll match exactly one wildcard dispatch rule.

0 commit comments

Comments
 (0)