Skip to content

Commit f798fd6

Browse files
committed
Validate configuration
1 parent edc14d8 commit f798fd6

File tree

6 files changed

+258
-8
lines changed

6 files changed

+258
-8
lines changed

cmd/server.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@ func runServer(
6868
cfg *config.Server,
6969
) error {
7070

71+
err := cfg.Validate()
72+
if err != nil {
73+
return err
74+
}
75+
7176
httpProbe := prober.NewHTTP()
7277
{
7378
logger.Infof("setting up HTTP server")

pkg/config/config.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,50 @@ type Server struct {
6161
}
6262
}
6363

64+
func (c Server) Validate() error {
65+
if c.HTTP.ListenAddress == "" {
66+
return errors.New("http listen address must not be empty")
67+
}
68+
if c.HTTP.GracePeriod < 0 {
69+
return errors.New("http grace period must be greater than or equal to 0")
70+
}
71+
72+
if c.MQTT.ListenAddress == "" {
73+
return errors.New("mqtt listen address must not be empty")
74+
}
75+
if c.MQTT.GracePeriod < 0 {
76+
return errors.New("mqtt grace period must be greater than or equal to 0")
77+
}
78+
if c.MQTT.ReadTimeout < 0 {
79+
return errors.New("mqtt read timeout must be greater than or equal to 0")
80+
}
81+
if c.MQTT.WriteTimeout < 0 {
82+
return errors.New("mqtt write timeout must be greater than or equal to 0")
83+
}
84+
if c.MQTT.IdleTimeout < 0 {
85+
return errors.New("mqtt idle timeout must be greater than or equal to 0")
86+
}
87+
if c.MQTT.ReaderBufferSize < 0 {
88+
return errors.New("mqtt read buffer size must be greater than or equal to 0")
89+
}
90+
if c.MQTT.WriterBufferSize < 0 {
91+
return errors.New("mqtt write buffer size must be greater than or equal to 0")
92+
}
93+
if c.MQTT.Handler.Publish.Timeout < 0 {
94+
return errors.New("handler publish timeout must be greater than or equal to 0")
95+
}
96+
if c.MQTT.Publisher.Name == "" {
97+
return errors.New("publisher name must not be empty")
98+
}
99+
if c.MQTT.Publisher.Name == Kafka && c.MQTT.Publisher.Kafka.BootstrapServers == "" {
100+
return errors.New("kafka bootstrap servers must not be empty")
101+
}
102+
if c.MQTT.Publisher.Kafka.GracePeriod < 0 {
103+
return errors.New("kafka grace period must be greater than or equal to 0")
104+
}
105+
return nil
106+
}
107+
64108
type KafkaConfigArgs struct {
65109
conf kafka.ConfigMap
66110
}

pkg/config/config_test.go

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ package config
33
import (
44
"github.com/confluentinc/confluent-kafka-go/kafka"
55
"github.com/stretchr/testify/assert"
6+
"regexp"
67
"testing"
8+
"time"
79
)
810

911
func TestKafkaConfArgs(t *testing.T) {
@@ -47,3 +49,104 @@ func TestKafkaConfArgs(t *testing.T) {
4749
})
4850
}
4951
}
52+
53+
func TestTopicMappings(t *testing.T) {
54+
tests := []struct {
55+
name string
56+
input string
57+
output TopicMappings
58+
err error
59+
}{
60+
{
61+
name: "Set parameters",
62+
input: "temperature=temperature, humidity=.*/humidity,brightness=.*brightness, temperature=^cool$",
63+
output: TopicMappings{
64+
Mappings: []TopicMapping{
65+
{
66+
Topic: "temperature",
67+
RegExp: regexp.MustCompile(`temperature`),
68+
},
69+
{
70+
Topic: "humidity",
71+
RegExp: regexp.MustCompile(`.*/humidity`),
72+
},
73+
{
74+
Topic: "brightness",
75+
RegExp: regexp.MustCompile(`.*brightness`),
76+
},
77+
{
78+
Topic: "temperature",
79+
RegExp: regexp.MustCompile(`^cool$`),
80+
},
81+
},
82+
},
83+
},
84+
}
85+
for _, tc := range tests {
86+
t.Run(tc.name, func(t *testing.T) {
87+
a := assert.New(t)
88+
89+
s := new(Server)
90+
err := s.MQTT.Publisher.Kafka.TopicMappings.Set(tc.input)
91+
a.Equal(tc.err, err)
92+
a.Equal(tc.output, s.MQTT.Publisher.Kafka.TopicMappings)
93+
})
94+
}
95+
}
96+
97+
func TestConfigValidation(t *testing.T) {
98+
tests := []struct {
99+
name string
100+
factory func() *Server
101+
err error
102+
}{
103+
{
104+
name: "noop publisher",
105+
factory: func() *Server {
106+
s := new(Server)
107+
s.HTTP.ListenAddress = "localhost:9090"
108+
s.MQTT.ListenAddress = "localhost:1883"
109+
s.MQTT.Publisher.Name = Noop
110+
return s
111+
},
112+
},
113+
{
114+
name: "kafka publisher",
115+
factory: func() *Server {
116+
s := new(Server)
117+
s.HTTP.ListenAddress = "localhost:9090"
118+
s.MQTT.ListenAddress = "localhost:1883"
119+
s.MQTT.Publisher.Name = Kafka
120+
s.MQTT.Publisher.Kafka.BootstrapServers = "localhost:9092"
121+
return s
122+
},
123+
},
124+
{
125+
name: "kafka publisher with params",
126+
factory: func() *Server {
127+
s := new(Server)
128+
s.HTTP.ListenAddress = "localhost:9090"
129+
s.HTTP.GracePeriod = 10 * time.Second
130+
s.MQTT.ListenAddress = "localhost:1883"
131+
s.MQTT.GracePeriod = 1 * time.Second
132+
s.MQTT.ReadTimeout = 1 * time.Second
133+
s.MQTT.WriteTimeout = 1 * time.Second
134+
s.MQTT.IdleTimeout = 1 * time.Second
135+
s.MQTT.ReaderBufferSize = 256
136+
s.MQTT.WriterBufferSize = 256
137+
s.MQTT.Publisher.Name = Kafka
138+
s.MQTT.Publisher.Kafka.BootstrapServers = "localhost:9092"
139+
s.MQTT.Publisher.Kafka.GracePeriod = 10 * time.Second
140+
return s
141+
},
142+
},
143+
}
144+
for _, tc := range tests {
145+
t.Run(tc.name, func(t *testing.T) {
146+
input := tc.factory()
147+
a := assert.New(t)
148+
err := input.Validate()
149+
a.Equal(tc.err, err)
150+
})
151+
}
152+
}

pkg/server/http/http_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package http
2+
3+
import (
4+
"github.com/grepplabs/mqtt-proxy/pkg/log"
5+
"github.com/grepplabs/mqtt-proxy/pkg/prober"
6+
"github.com/prometheus/client_golang/prometheus"
7+
"github.com/stretchr/testify/assert"
8+
"testing"
9+
"time"
10+
)
11+
12+
func TestNew(t *testing.T) {
13+
a := assert.New(t)
14+
15+
logger := log.NewDefaultLogger()
16+
registry := prometheus.NewRegistry()
17+
serverProber := prober.NewHTTP()
18+
19+
server := New(logger, registry, serverProber,
20+
WithListen("0.0.0.0:1883"),
21+
WithGracePeriod(10*time.Second),
22+
)
23+
24+
a.NotNil(server.logger)
25+
a.NotNil(server.opts)
26+
a.NotNil(server.srv)
27+
a.NotNil(server.mux)
28+
a.Same(serverProber, server.prober)
29+
30+
a.Equal("0.0.0.0:1883", server.opts.listen)
31+
a.Equal(10*time.Second, server.opts.gracePeriod)
32+
33+
a.Same(server.mux, server.srv.Handler)
34+
a.Equal(server.opts.listen, server.srv.Addr)
35+
}

pkg/server/mqtt/mqtt.go

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,24 @@ type Server struct {
2020
}
2121

2222
// New creates a new Server.
23-
func New(logger log.Logger, registry *prometheus.Registry, prober *prober.HTTPProbe, opts ...Option) *Server {
23+
func New(logger log.Logger, _ *prometheus.Registry, prober *prober.HTTPProbe, opts ...Option) *Server {
2424
options := options{}
2525
for _, o := range opts {
2626
o.apply(&options)
2727
}
2828
mux := mqttserver.NewServeMux(logger)
2929

3030
s := &mqttserver.Server{
31-
Network: options.network,
32-
Addr: options.listen,
33-
Handler: options.handler,
34-
ReadTimeout: options.readTimeout,
35-
WriteTimeout: options.writeTimeout,
36-
TLSConfig: options.tlsConfig,
37-
ErrorLog: logger,
31+
Network: options.network,
32+
Addr: options.listen,
33+
Handler: options.handler,
34+
ReadTimeout: options.readTimeout,
35+
WriteTimeout: options.writeTimeout,
36+
IdleTimeout: options.idleTimeout,
37+
WriterBufferSize: options.writerBufferSize,
38+
ReaderBufferSize: options.readerBufferSize,
39+
TLSConfig: options.tlsConfig,
40+
ErrorLog: logger,
3841
}
3942

4043
return &Server{

pkg/server/mqtt/mqtt_test.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package mqtt
2+
3+
import (
4+
"crypto/tls"
5+
"github.com/grepplabs/mqtt-proxy/pkg/log"
6+
mqtthandler "github.com/grepplabs/mqtt-proxy/pkg/mqtt/handler"
7+
"github.com/grepplabs/mqtt-proxy/pkg/prober"
8+
"github.com/prometheus/client_golang/prometheus"
9+
"github.com/stretchr/testify/assert"
10+
"testing"
11+
"time"
12+
)
13+
14+
func TestNew(t *testing.T) {
15+
a := assert.New(t)
16+
17+
logger := log.NewDefaultLogger()
18+
registry := prometheus.NewRegistry()
19+
serverProber := prober.NewHTTP()
20+
tlsCfg := &tls.Config{}
21+
handler := &mqtthandler.MQTTHandler{}
22+
23+
server := New(logger, registry, serverProber,
24+
WithNetwork("tcp"),
25+
WithListen("0.0.0.0:1883"),
26+
WithReadTimeout(10*time.Second),
27+
WithWriteTimeout(11*time.Second),
28+
WithIdleTimeout(12*time.Second),
29+
WithReaderBufferSize(2048),
30+
WithWriterBufferSize(4096),
31+
WithTLSConfig(tlsCfg),
32+
WithHandler(handler),
33+
)
34+
35+
a.NotNil(server.logger)
36+
a.NotNil(server.opts)
37+
a.NotNil(server.srv)
38+
a.NotNil(server.mux)
39+
a.Same(serverProber, server.prober)
40+
41+
a.Equal("tcp", server.opts.network)
42+
a.Equal("0.0.0.0:1883", server.opts.listen)
43+
a.Equal(10*time.Second, server.opts.readTimeout)
44+
a.Equal(11*time.Second, server.opts.writeTimeout)
45+
a.Equal(12*time.Second, server.opts.idleTimeout)
46+
a.Equal(2048, server.opts.readerBufferSize)
47+
a.Equal(4096, server.opts.writerBufferSize)
48+
a.Equal(handler, server.opts.handler)
49+
50+
a.Equal("tcp", server.srv.Network)
51+
a.Equal("0.0.0.0:1883", server.srv.Addr)
52+
a.Equal(10*time.Second, server.srv.ReadTimeout)
53+
a.Equal(11*time.Second, server.srv.WriteTimeout)
54+
a.Equal(12*time.Second, server.srv.IdleTimeout)
55+
a.Equal(2048, server.srv.ReaderBufferSize)
56+
a.Equal(4096, server.srv.WriterBufferSize)
57+
a.NotNil(server.srv.ErrorLog)
58+
a.Equal(handler, server.srv.Handler)
59+
60+
}

0 commit comments

Comments
 (0)