Skip to content

Commit 531eb7e

Browse files
authored
feat(wren-launcher): support MySQL and Postgres data source for dbt-tool (#1876)
1 parent 66cb96e commit 531eb7e

File tree

6 files changed

+257
-26
lines changed

6 files changed

+257
-26
lines changed

wren-launcher/commands/dbt/README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
# Requirement for DBT project
2+
This part outlines some requirements for the target dbt project:
3+
- Ensure the DBT project is qualified and generates the required files:
4+
- `catalog.json`
5+
- `manifest.json`
6+
Execute the following commands:
7+
```
8+
dbt build
9+
dbt docs generate
10+
```
11+
- Prepare the profile of the dbt project for the connection info of your database.
12+
- `profiles.yml`
13+
14+
115
# How to Support a New Data Source
216

317
This document outlines the steps required to add support for a new data source to the dbt project converter.

wren-launcher/commands/dbt/converter.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,18 @@ func ConvertDbtProjectCore(opts ConvertOptions) (*ConvertResult, error) {
154154
"format": typedDS.Format,
155155
},
156156
}
157+
case *WrenMysqlDataSource:
158+
wrenDataSource = map[string]interface{}{
159+
"type": "mysql",
160+
"properties": map[string]interface{}{
161+
"host": typedDS.Host,
162+
"port": typedDS.Port,
163+
"database": typedDS.Database,
164+
"user": typedDS.User,
165+
"password": typedDS.Password,
166+
"sslMode": typedDS.SslMode,
167+
},
168+
}
157169
default:
158170
pterm.Warning.Printf("Warning: Unsupported data source type: %s\n", ds.GetType())
159171
wrenDataSource = map[string]interface{}{

wren-launcher/commands/dbt/data_source.go

Lines changed: 131 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package dbt
33
import (
44
"fmt"
55
"path/filepath"
6+
"strconv"
67
"strings"
78

89
"github.com/pterm/pterm"
@@ -16,6 +17,7 @@ const (
1617
timestampType = "timestamp"
1718
doubleType = "double"
1819
booleanType = "boolean"
20+
postgresType = "postgres"
1921
)
2022

2123
// Constants for SQL data types
@@ -73,10 +75,12 @@ func FromDbtProfiles(profiles *DbtProfiles) ([]DataSource, error) {
7375
// convertConnectionToDataSource converts connection to corresponding DataSource based on connection type
7476
func convertConnectionToDataSource(conn DbtConnection, dbtHomePath, profileName, outputName string) (DataSource, error) {
7577
switch strings.ToLower(conn.Type) {
76-
case "postgres", "postgresql":
78+
case postgresType, "postgresql":
7779
return convertToPostgresDataSource(conn)
7880
case "duckdb":
7981
return convertToLocalFileDataSource(conn, dbtHomePath)
82+
case "mysql":
83+
return convertToMysqlDataSource(conn)
8084
default:
8185
// For unsupported database types, we can choose to ignore or return error
8286
// Here we choose to return nil and log a warning
@@ -87,19 +91,26 @@ func convertConnectionToDataSource(conn DbtConnection, dbtHomePath, profileName,
8791

8892
// convertToPostgresDataSource converts to PostgreSQL data source
8993
func convertToPostgresDataSource(conn DbtConnection) (*WrenPostgresDataSource, error) {
94+
// For PostgreSQL, prefer dbname over database field
95+
dbName := conn.DbName
96+
if dbName == "" {
97+
dbName = conn.Database
98+
}
99+
100+
pterm.Info.Printf("Converting Postgres data source: %s:%d/%s\n", conn.Host, conn.Port, dbName)
101+
port := strconv.Itoa(conn.Port)
102+
if conn.Port == 0 {
103+
port = "5432"
104+
}
105+
90106
ds := &WrenPostgresDataSource{
91107
Host: conn.Host,
92-
Port: conn.Port,
93-
Database: conn.Database,
108+
Port: port,
109+
Database: dbName,
94110
User: conn.User,
95111
Password: conn.Password,
96112
}
97113

98-
// If no port is specified, use PostgreSQL default port
99-
if ds.Port == 0 {
100-
ds.Port = 5432
101-
}
102-
103114
return ds, nil
104115
}
105116

@@ -143,6 +154,30 @@ func convertToLocalFileDataSource(conn DbtConnection, dbtHome string) (*WrenLoca
143154
}, nil
144155
}
145156

157+
func convertToMysqlDataSource(conn DbtConnection) (*WrenMysqlDataSource, error) {
158+
pterm.Info.Printf("Converting MySQL data source: %s:%d/%s\n", conn.Host, conn.Port, conn.Database)
159+
160+
sslMode := "ENABLED" // Default SSL mode
161+
if conn.SslDisable {
162+
sslMode = "DISABLED"
163+
}
164+
port := strconv.Itoa(conn.Port)
165+
if conn.Port == 0 {
166+
port = "3306"
167+
}
168+
169+
ds := &WrenMysqlDataSource{
170+
Host: conn.Host,
171+
Port: port,
172+
Database: conn.Database,
173+
User: conn.User,
174+
Password: conn.Password,
175+
SslMode: sslMode,
176+
}
177+
178+
return ds, nil
179+
}
180+
146181
type WrenLocalFileDataSource struct {
147182
Url string `json:"url"`
148183
Format string `json:"format"`
@@ -189,15 +224,15 @@ func (ds *WrenLocalFileDataSource) MapType(sourceType string) string {
189224

190225
type WrenPostgresDataSource struct {
191226
Host string `json:"host"`
192-
Port int `json:"port"`
227+
Port string `json:"port"`
193228
Database string `json:"database"`
194229
User string `json:"user"`
195230
Password string `json:"password"`
196231
}
197232

198233
// GetType implements DataSource interface
199234
func (ds *WrenPostgresDataSource) GetType() string {
200-
return "postgres"
235+
return postgresType
201236
}
202237

203238
// Validate implements DataSource interface
@@ -211,7 +246,14 @@ func (ds *WrenPostgresDataSource) Validate() error {
211246
if ds.User == "" {
212247
return fmt.Errorf("user cannot be empty")
213248
}
214-
if ds.Port <= 0 || ds.Port > 65535 {
249+
if ds.Port == "" {
250+
return fmt.Errorf("port must be specified")
251+
}
252+
port, err := strconv.Atoi(ds.Port)
253+
if err != nil {
254+
return fmt.Errorf("port must be a valid number")
255+
}
256+
if port <= 0 || port > 65535 {
215257
return fmt.Errorf("port must be between 1 and 65535")
216258
}
217259
return nil
@@ -222,6 +264,83 @@ func (ds *WrenPostgresDataSource) MapType(sourceType string) string {
222264
return sourceType
223265
}
224266

267+
type WrenMysqlDataSource struct {
268+
Database string `json:"database"`
269+
Host string `json:"host"`
270+
Password string `json:"password"`
271+
Port string `json:"port"`
272+
User string `json:"user"`
273+
SslCA string `json:"ssl_ca,omitempty"` // Optional SSL CA file for MySQL
274+
SslMode string `json:"ssl_mode,omitempty"` // Optional SSL mode for MySQL
275+
}
276+
277+
// GetType implements DataSource interface
278+
func (ds *WrenMysqlDataSource) GetType() string {
279+
return "mysql"
280+
}
281+
282+
// Validate implements DataSource interface
283+
func (ds *WrenMysqlDataSource) Validate() error {
284+
if ds.Host == "" {
285+
return fmt.Errorf("host cannot be empty")
286+
}
287+
if ds.Database == "" {
288+
return fmt.Errorf("database cannot be empty")
289+
}
290+
if ds.User == "" {
291+
return fmt.Errorf("user cannot be empty")
292+
}
293+
if ds.Port == "" {
294+
return fmt.Errorf("port must be specified")
295+
}
296+
port, err := strconv.Atoi(ds.Port)
297+
if err != nil {
298+
return fmt.Errorf("port must be a valid number")
299+
}
300+
if port <= 0 || port > 65535 {
301+
return fmt.Errorf("port must be between 1 and 65535")
302+
}
303+
return nil
304+
}
305+
306+
func (ds *WrenMysqlDataSource) MapType(sourceType string) string {
307+
// This method is not used in WrenMysqlDataSource, but required by DataSource interface
308+
sourceType = strings.ToUpper(sourceType)
309+
switch sourceType {
310+
case "CHAR":
311+
return "char"
312+
case "VARCHAR":
313+
return varcharType
314+
case "TEXT", "TINYTEXT", "MEDIUMTEXT", "LONGTEXT", "ENUM", "SET":
315+
return "text"
316+
case "BIT", "TINYINT":
317+
return "TINYINT"
318+
case "SMALLINT":
319+
return "SMALLINT"
320+
case "MEDIUMINT", "INT", "INTEGER":
321+
return "INTEGER"
322+
case "BIGINT":
323+
return "BIGINT"
324+
case "FLOAT", "DOUBLE":
325+
return "DOUBLE"
326+
case "DECIMAL", "NUMERIC":
327+
return "DECIMAL"
328+
case "DATE":
329+
return "DATE"
330+
case "DATETIME":
331+
return "DATETIME"
332+
case "TIMESTAMP":
333+
return "TIMESTAMPTZ"
334+
case "BOOLEAN", "BOOL":
335+
return "BOOLEAN"
336+
case "JSON":
337+
return "JSON"
338+
default:
339+
// Return the original type if no mapping is found
340+
return strings.ToLower(sourceType)
341+
}
342+
}
343+
225344
// GetActiveDataSources gets active data sources based on specified profile and target
226345
// If profileName is empty, it will use the first found profile
227346
// If targetName is empty, it will use the profile's default target
@@ -326,7 +445,7 @@ func (d *DefaultDataSource) MapType(sourceType string) string {
326445
case "integer", "int", "bigint", "int64":
327446
return "integer"
328447
case "varchar", "text", "string", "char":
329-
return "varchar"
448+
return varcharType
330449
case "timestamp", "datetime", "date":
331450
return "timestamp"
332451
case "double", "float", "decimal", "numeric":

0 commit comments

Comments
 (0)