Skip to content

Commit 95a09c7

Browse files
committed
feat: add option to generate command to only generate allowed message ids
1 parent 7798b54 commit 95a09c7

File tree

4 files changed

+359
-287
lines changed

4 files changed

+359
-287
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ func main() {
7575
It is possible to generate Go code from a `.dbc` file.
7676

7777
```
78-
$ go run go.einride.tech/can/cmd/cantool generate <dbc file root folder> <output folder>
78+
$ go run go.einride.tech/can/cmd/cantool generate <dbc file root folder> <output folder> [<allowed-message-ids>...]
7979
```
8080

8181
In order to generate Go code that makes sense, we currently perform some

cmd/cantool/main.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ func generateCommand(app *kingpin.Application) {
5252
Arg("output-dir", "output directory").
5353
Required().
5454
String()
55+
allowedMessageIds := command.
56+
Arg("allowed-message-ids", "optional filter of message-ids to compile").
57+
Uint32List()
5558
command.Action(func(_ *kingpin.ParseContext) error {
5659
return filepath.Walk(*inputDir, func(p string, i os.FileInfo, err error) error {
5760
if err != nil {
@@ -66,7 +69,7 @@ func generateCommand(app *kingpin.Application) {
6669
}
6770
outputFile := relPath + ".go"
6871
outputPath := filepath.Join(*outputDir, outputFile)
69-
return genGo(p, outputPath)
72+
return genGo(p, outputPath, *allowedMessageIds)
7073
})
7174
})
7275
}
@@ -143,15 +146,15 @@ func analyzers() []*analysis.Analyzer {
143146
}
144147
}
145148

146-
func genGo(inputFile, outputFile string) error {
149+
func genGo(inputFile, outputFile string, allowedMessageIds []uint32) error {
147150
if err := os.MkdirAll(filepath.Dir(outputFile), 0o755); err != nil {
148151
return err
149152
}
150153
input, err := os.ReadFile(inputFile)
151154
if err != nil {
152155
return err
153156
}
154-
result, err := generate.Compile(inputFile, input)
157+
result, err := generate.Compile(inputFile, input, generate.WithAllowedMessageIds(allowedMessageIds))
155158
if err != nil {
156159
return err
157160
}

internal/generate/compile.go

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ type CompileResult struct {
1414
Warnings []error
1515
}
1616

17-
func Compile(sourceFile string, data []byte) (result *CompileResult, err error) {
17+
type CompileOption func(*compiler)
18+
19+
func Compile(sourceFile string, data []byte, options ...CompileOption) (result *CompileResult, err error) {
1820
p := dbc.NewParser(sourceFile, data)
1921
if err := p.Parse(); err != nil {
2022
return nil, fmt.Errorf("failed to parse DBC source file: %w", err)
@@ -24,12 +26,30 @@ func Compile(sourceFile string, data []byte) (result *CompileResult, err error)
2426
db: &descriptor.Database{SourceFile: sourceFile},
2527
defs: defs,
2628
}
29+
30+
for _, opt := range options {
31+
opt(c)
32+
}
33+
2734
c.collectDescriptors()
2835
c.addMetadata()
2936
c.sortDescriptors()
3037
return &CompileResult{Database: c.db, Warnings: c.warnings}, nil
3138
}
3239

40+
func WithAllowedMessageIds(ids []uint32) CompileOption {
41+
return func(c *compiler) {
42+
if ids == nil {
43+
return
44+
}
45+
46+
c.onlyCompileIds = make([]dbc.MessageID, 0, len(ids))
47+
for _, id := range ids {
48+
c.onlyCompileIds = append(c.onlyCompileIds, dbc.MessageID(id))
49+
}
50+
}
51+
}
52+
3353
type compileError struct {
3454
def dbc.Def
3555
reason string
@@ -40,9 +60,10 @@ func (e *compileError) Error() string {
4060
}
4161

4262
type compiler struct {
43-
db *descriptor.Database
44-
defs []dbc.Def
45-
warnings []error
63+
onlyCompileIds []dbc.MessageID
64+
db *descriptor.Database
65+
defs []dbc.Def
66+
warnings []error
4667
}
4768

4869
func (c *compiler) addWarning(warning error) {
@@ -55,7 +76,7 @@ func (c *compiler) collectDescriptors() {
5576
case *dbc.VersionDef:
5677
c.db.Version = def.Version
5778
case *dbc.MessageDef:
58-
if def.MessageID == dbc.IndependentSignalsMessageID {
79+
if c.skipCompile(def.MessageID) {
5980
continue // don't compile
6081
}
6182
message := &descriptor.Message{
@@ -101,7 +122,9 @@ func (c *compiler) addMetadata() {
101122
case *dbc.SignalValueTypeDef:
102123
signal, ok := c.db.Signal(def.MessageID.ToCAN(), string(def.SignalName))
103124
if !ok {
104-
c.addWarning(&compileError{def: def, reason: "no declared signal"})
125+
if !c.skipCompile(def.MessageID) {
126+
c.addWarning(&compileError{def: def, reason: "no declared signal"})
127+
}
105128
continue
106129
}
107130
switch def.SignalValueType {
@@ -121,7 +144,7 @@ func (c *compiler) addMetadata() {
121144
case *dbc.CommentDef:
122145
switch def.ObjectType {
123146
case dbc.ObjectTypeMessage:
124-
if def.MessageID == dbc.IndependentSignalsMessageID {
147+
if c.skipCompile(def.MessageID) {
125148
continue // don't compile
126149
}
127150
message, ok := c.db.Message(def.MessageID.ToCAN())
@@ -131,7 +154,7 @@ func (c *compiler) addMetadata() {
131154
}
132155
message.Description = def.Comment
133156
case dbc.ObjectTypeSignal:
134-
if def.MessageID == dbc.IndependentSignalsMessageID {
157+
if c.skipCompile(def.MessageID) {
135158
continue // don't compile
136159
}
137160
signal, ok := c.db.Signal(def.MessageID.ToCAN(), string(def.SignalName))
@@ -149,7 +172,7 @@ func (c *compiler) addMetadata() {
149172
node.Description = def.Comment
150173
}
151174
case *dbc.ValueDescriptionsDef:
152-
if def.MessageID == dbc.IndependentSignalsMessageID {
175+
if c.skipCompile(def.MessageID) {
153176
continue // don't compile
154177
}
155178
if def.ObjectType != dbc.ObjectTypeSignal {
@@ -167,6 +190,9 @@ func (c *compiler) addMetadata() {
167190
})
168191
}
169192
case *dbc.AttributeValueForObjectDef:
193+
if c.skipCompile(def.MessageID) {
194+
continue // don't compile
195+
}
170196
switch def.ObjectType {
171197
case dbc.ObjectTypeMessage:
172198
msg, ok := c.db.Message(def.MessageID.ToCAN())
@@ -225,3 +251,15 @@ func (c *compiler) sortDescriptors() {
225251
}
226252
}
227253
}
254+
255+
func (c *compiler) skipCompile(id dbc.MessageID) bool {
256+
if id == dbc.IndependentSignalsMessageID {
257+
return true
258+
}
259+
for _, allowedMessageId := range c.onlyCompileIds {
260+
if allowedMessageId == id {
261+
return false
262+
}
263+
}
264+
return len(c.onlyCompileIds) > 0
265+
}

0 commit comments

Comments
 (0)