128 changes: 128 additions & 0 deletions accessors/dataflow/client.go
@@ -0,0 +1,128 @@
package dataflow

import (
"context"
"fmt"
"strconv"
"sync"
"time"

dataflow "cloud.google.com/go/dataflow/apiv1beta3"
"github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils"
"github.com/GoogleCloudPlatform/spanner-migration-tool/profiles"
"github.com/GoogleCloudPlatform/spanner-migration-tool/streaming"
dataflowpb "cloud.google.com/go/dataflow/apiv1beta3/dataflowpb"
otherUtils "github.com/GoogleCloudPlatform/spanner-migration-tool/utils"
)

var (
// Default value for maxWorkers.
maxWorkers int32 = 50
// Default value for NumWorkers.
numWorkers int32 = 1
// Max allowed value for maxWorkers and numWorkers.
MAX_WORKER_LIMIT int32 = 1000
// Min allowed value for maxWorkers and numWorkers.
MIN_WORKER_LIMIT int32 = 1
)

var once sync.Once

var dataflowClient *dataflow.FlexTemplatesClient

// GetInstance lazily initializes and returns a process-wide singleton
// FlexTemplatesClient. Note that the error from NewFlexTemplatesClient is
// currently discarded, so the returned client may be nil if creation fails.
func GetInstance(ctx context.Context) *dataflow.FlexTemplatesClient {
once.Do(func() {
dataflowClient, _ = dataflow.NewFlexTemplatesClient(ctx)
})
return dataflowClient
}

// LaunchDataflowJob launches the Datastream-to-Spanner Dataflow flex template with the given
// streaming configuration and returns the launched job ID along with the equivalent gcloud command.
func LaunchDataflowJob(ctx context.Context, dataflowClient *dataflow.FlexTemplatesClient, streamingCfg streaming.StreamingCfg, spannerDetails profiles.TargetProfile, inputFilePattern string, gcpProjectId string) (string, string, error) {
project, instance, dbName, _ := spannerDetails.GetResourceIds(ctx, time.Now(), "", nil)
dataflowCfg := streamingCfg.DataflowCfg
fmt.Println("Launching dataflow job ", dataflowCfg.JobName, " in ", project, "-", dataflowCfg.Location)
var dataflowHostProjectId string
if streamingCfg.DataflowCfg.HostProjectId == "" {
dataflowHostProjectId, _ = utils.GetProject()
} else {
dataflowHostProjectId = streamingCfg.DataflowCfg.HostProjectId
}

dataflowSubnetwork := ""

// If custom network is not selected, use public IP. Typical for internal testing flow.
workerIpAddressConfig := dataflowpb.WorkerIPAddressConfiguration_WORKER_IP_PUBLIC

if streamingCfg.DataflowCfg.Network != "" {
workerIpAddressConfig = dataflowpb.WorkerIPAddressConfiguration_WORKER_IP_PRIVATE
if streamingCfg.DataflowCfg.Subnetwork == "" {
return "", "", fmt.Errorf("if network is specified, subnetwork cannot be empty")
} else {
dataflowSubnetwork = fmt.Sprintf("https://www.googleapis.com/compute/v1/projects/%s/regions/%s/subnetworks/%s", dataflowHostProjectId, streamingCfg.DataflowCfg.Location, dataflowCfg.Subnetwork)
}
}

if streamingCfg.DataflowCfg.MaxWorkers != "" {
intVal, err := strconv.ParseInt(streamingCfg.DataflowCfg.MaxWorkers, 10, 32) // bitSize 32 rejects values that would overflow int32
if err != nil {
return "", "", fmt.Errorf("could not parse MaxWorkers parameter %s, please provide a positive integer as input", streamingCfg.DataflowCfg.MaxWorkers)
}
maxWorkers = int32(intVal)
if maxWorkers < MIN_WORKER_LIMIT || maxWorkers > MAX_WORKER_LIMIT {
return "", "", fmt.Errorf("maxWorkers should lie in the range [%d, %d]", MIN_WORKER_LIMIT, MAX_WORKER_LIMIT)
}
}
if streamingCfg.DataflowCfg.NumWorkers != "" {
intVal, err := strconv.ParseInt(streamingCfg.DataflowCfg.NumWorkers, 10, 32) // bitSize 32 rejects values that would overflow int32
if err != nil {
return "", "", fmt.Errorf("could not parse NumWorkers parameter %s, please provide a positive integer as input", dataflowCfg.NumWorkers)
}
numWorkers = int32(intVal)
if numWorkers < MIN_WORKER_LIMIT || numWorkers > MAX_WORKER_LIMIT {
return "", "", fmt.Errorf("numWorkers should lie in the range [%d, %d]", MIN_WORKER_LIMIT, MAX_WORKER_LIMIT)
}
}
launchParameters := &dataflowpb.LaunchFlexTemplateParameter{
JobName: streamingCfg.DataflowCfg.JobName,
Template: &dataflowpb.LaunchFlexTemplateParameter_ContainerSpecGcsPath{ContainerSpecGcsPath: "gs://dataflow-templates-southamerica-west1/2023-09-12-00_RC00/flex/Cloud_Datastream_to_Spanner"},
Parameters: map[string]string{
"inputFilePattern": otherUtils.ConcatDirectoryPath(inputFilePattern, "data"),
"streamName": fmt.Sprintf("projects/%s/locations/%s/streams/%s", gcpProjectId, streamingCfg.DatastreamCfg.StreamLocation, streamingCfg.DatastreamCfg.StreamId),
"instanceId": instance,
"databaseId": dbName,
"sessionFilePath": streamingCfg.TmpDir + "session.json",
"deadLetterQueueDirectory": inputFilePattern + "dlq",
"transformationContextFilePath": streamingCfg.TmpDir + "transformationContext.json",
"gcsPubSubSubscription": fmt.Sprintf("projects/%s/subscriptions/%s", gcpProjectId, streamingCfg.PubsubCfg.SubscriptionId),
},
Environment: &dataflowpb.FlexTemplateRuntimeEnvironment{
MaxWorkers: maxWorkers,
NumWorkers: numWorkers,
ServiceAccountEmail: streamingCfg.DataflowCfg.ServiceAccountEmail,
AutoscalingAlgorithm: 2, // 2 corresponds to AUTOSCALING_ALGORITHM_BASIC
EnableStreamingEngine: true,
Network: streamingCfg.DataflowCfg.Network,
Subnetwork: dataflowSubnetwork,
IpConfiguration: workerIpAddressConfig,
},
}
req := &dataflowpb.LaunchFlexTemplateRequest{
ProjectId: gcpProjectId,
LaunchParameter: launchParameters,
Location: streamingCfg.DataflowCfg.Location,
}
fmt.Println("Created flex template request body...")

respDf, err := dataflowClient.LaunchFlexTemplate(ctx, req)
if err != nil {
fmt.Printf("flexTemplateRequest: %+v\n", req)
return "", "", fmt.Errorf("unable to launch template: %v", err)
}
gcloudDfCmd := utils.GetGcloudDataflowCommand(req)
fmt.Printf("\nEquivalent gCloud command for job %s:\n%s\n\n", req.LaunchParameter.JobName, gcloudDfCmd)
return respDf.Job.Id, gcloudDfCmd, nil
}
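A minimal usage sketch for the new dataflow accessor (a hypothetical caller, not part of this change: the project, bucket, job name, and region values are placeholders, and the TargetProfile is assumed to be built elsewhere; only fields that the diff itself reads are set):

package main

import (
	"context"
	"log"

	dataflowacc "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/dataflow"
	"github.com/GoogleCloudPlatform/spanner-migration-tool/profiles"
	"github.com/GoogleCloudPlatform/spanner-migration-tool/streaming"
)

func main() {
	ctx := context.Background()
	client := dataflowacc.GetInstance(ctx) // lazily-created singleton FlexTemplatesClient

	// Placeholder configuration; real values come from the migration profile.
	var streamingCfg streaming.StreamingCfg
	streamingCfg.DataflowCfg.JobName = "smt-demo-job"
	streamingCfg.DataflowCfg.Location = "us-central1"
	streamingCfg.TmpDir = "gs://my-bucket/tmp/"

	var target profiles.TargetProfile // assumed to be populated elsewhere; its construction is outside this diff

	jobID, gcloudCmd, err := dataflowacc.LaunchDataflowJob(ctx, client, streamingCfg, target, "gs://my-bucket/data/", "my-project")
	if err != nil {
		log.Fatalf("launch failed: %v", err)
	}
	log.Printf("launched Dataflow job %s; equivalent gcloud command:\n%s", jobID, gcloudCmd)
}

Because GetInstance swallows the constructor error, a production caller would likely want to nil-check the returned client before use.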
221 changes: 221 additions & 0 deletions accessors/datastream/client.go
@@ -0,0 +1,221 @@
package datastream

import (
"context"
"fmt"
"sync"

datastream "cloud.google.com/go/datastream/apiv1"
datastreampb "cloud.google.com/go/datastream/apiv1/datastreampb"
"github.com/GoogleCloudPlatform/spanner-migration-tool/profiles"
"github.com/GoogleCloudPlatform/spanner-migration-tool/streaming"
"github.com/GoogleCloudPlatform/spanner-migration-tool/utils/constants"
"google.golang.org/protobuf/types/known/fieldmaskpb"
)

var once sync.Once

var datastreamClient *datastream.Client

// GetInstance lazily initializes and returns a process-wide singleton Datastream
// client. Note that the error from NewClient is currently discarded, so the
// returned client may be nil if creation fails.
func GetInstance(ctx context.Context) *datastream.Client {
once.Do(func() {
datastreamClient, _ = datastream.NewClient(ctx)
})
return datastreamClient
}

// GetGCSPathFromConnectionProfile fetches the destination connection profile and returns its GCS profile.
func GetGCSPathFromConnectionProfile(ctx context.Context, datastreamClient *datastream.Client, projectID string, datastreamDestinationConnCfg streaming.DstConnCfg) (*datastreampb.GcsProfile, error) {
dstProf := fmt.Sprintf("projects/%s/locations/%s/connectionProfiles/%s", projectID, datastreamDestinationConnCfg.Location, datastreamDestinationConnCfg.Name)
res, err := datastreamClient.GetConnectionProfile(ctx, &datastreampb.GetConnectionProfileRequest{Name: dstProf})
if err != nil {
return nil, fmt.Errorf("could not get connection profiles: %v", err)
}
// Fetch the GCS path from the target connection profile.
gcsProfile := res.Profile.(*datastreampb.ConnectionProfile_GcsProfile).GcsProfile
return gcsProfile, nil
}

// CreateStream creates a Datastream stream (backfill-all, requested state RUNNING) from the given
// source database to the GCS destination identified by prefix, and waits for the operation to complete.
func CreateStream(ctx context.Context, datastreamClient *datastream.Client, gcpProjectId string, prefix string, sourceDatabaseType string, sourceDatabaseName string, datastreamCfg streaming.DatastreamCfg) error {
gcsDstCfg := &datastreampb.GcsDestinationConfig{
Path: prefix,
FileFormat: &datastreampb.GcsDestinationConfig_AvroFileFormat{},
}
srcCfg := &datastreampb.SourceConfig{
SourceConnectionProfile: fmt.Sprintf("projects/%s/locations/%s/connectionProfiles/%s", gcpProjectId, datastreamCfg.SourceConnectionConfig.Location, datastreamCfg.SourceConnectionConfig.Name),
}
var dbList []profiles.LogicalShard
dbList = append(dbList, profiles.LogicalShard{DbName: sourceDatabaseName})
err := getSourceStreamConfig(srcCfg, sourceDatabaseType, dbList, datastreamCfg)
if err != nil {
return fmt.Errorf("could not get source stream config: %v", err)
}

dstCfg := &datastreampb.DestinationConfig{
DestinationConnectionProfile: fmt.Sprintf("projects/%s/locations/%s/connectionProfiles/%s", gcpProjectId, datastreamCfg.DestinationConnectionConfig.Location, datastreamCfg.DestinationConnectionConfig.Name),
DestinationStreamConfig: &datastreampb.DestinationConfig_GcsDestinationConfig{GcsDestinationConfig: gcsDstCfg},
}
streamInfo := &datastreampb.Stream{
DisplayName: datastreamCfg.StreamDisplayName,
SourceConfig: srcCfg,
DestinationConfig: dstCfg,
State: datastreampb.Stream_RUNNING,
BackfillStrategy: &datastreampb.Stream_BackfillAll{BackfillAll: &datastreampb.Stream_BackfillAllStrategy{}},
}
createStreamRequest := &datastreampb.CreateStreamRequest{
Parent: fmt.Sprintf("projects/%s/locations/%s", gcpProjectId, datastreamCfg.StreamLocation),
StreamId: datastreamCfg.StreamId,
Stream: streamInfo,
}

fmt.Println("Created stream request..")

dsOp, err := datastreamClient.CreateStream(ctx, createStreamRequest)
if err != nil {
fmt.Printf("cannot create stream: createStreamRequest: %+v\n", createStreamRequest)
return fmt.Errorf("cannot create stream: %v ", err)
}

_, err = dsOp.Wait(ctx)
if err != nil {
fmt.Printf("datastream create operation failed: createStreamRequest: %+v\n", createStreamRequest)
return fmt.Errorf("datastream create operation failed: %v", err)
}
fmt.Println("Successfully created stream ", datastreamCfg.StreamId)
return nil
}

// UpdateStream updates the existing stream to the RUNNING state (only the state field is in the
// update mask) and waits for the operation to complete.
func UpdateStream(ctx context.Context, datastreamClient *datastream.Client, gcpProjectId string, prefix string, datastreamCfg streaming.DatastreamCfg) error {
gcsDstCfg := &datastreampb.GcsDestinationConfig{
Path: prefix,
FileFormat: &datastreampb.GcsDestinationConfig_AvroFileFormat{},
}
srcCfg := &datastreampb.SourceConfig{
SourceConnectionProfile: fmt.Sprintf("projects/%s/locations/%s/connectionProfiles/%s", gcpProjectId, datastreamCfg.SourceConnectionConfig.Location, datastreamCfg.SourceConnectionConfig.Name),
}
dstCfg := &datastreampb.DestinationConfig{
DestinationConnectionProfile: fmt.Sprintf("projects/%s/locations/%s/connectionProfiles/%s", gcpProjectId, datastreamCfg.DestinationConnectionConfig.Location, datastreamCfg.DestinationConnectionConfig.Name),
DestinationStreamConfig: &datastreampb.DestinationConfig_GcsDestinationConfig{GcsDestinationConfig: gcsDstCfg},
}

streamInfo := &datastreampb.Stream{
DisplayName: datastreamCfg.StreamDisplayName,
SourceConfig: srcCfg,
DestinationConfig: dstCfg,
State: datastreampb.Stream_RUNNING,
BackfillStrategy: &datastreampb.Stream_BackfillAll{BackfillAll: &datastreampb.Stream_BackfillAllStrategy{}},
}
streamInfo.Name = fmt.Sprintf("projects/%s/locations/%s/streams/%s", gcpProjectId, datastreamCfg.StreamLocation, datastreamCfg.StreamId)
updateStreamRequest := &datastreampb.UpdateStreamRequest{
UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"state"}},
Stream: streamInfo,
}
upOp, err := datastreamClient.UpdateStream(ctx, updateStreamRequest)
if err != nil {
return fmt.Errorf("could not create update request: %v", err)
}
_, err = upOp.Wait(ctx)
if err != nil {
return fmt.Errorf("update stream operation failed: %v", err)
}
fmt.Println("Done")
return nil
}

// getMysqlSourceStreamConfig builds the MySQL source configuration.
// dbList is the list of databases (logical shards) to be migrated; tableList is the common list
// of tables that need to be migrated from each database.
func getMysqlSourceStreamConfig(dbList []profiles.LogicalShard, tableList []string) *datastreampb.SourceConfig_MysqlSourceConfig {
mysqlTables := []*datastreampb.MysqlTable{}
for _, table := range tableList {
includeTable := &datastreampb.MysqlTable{
Table: table,
}
mysqlTables = append(mysqlTables, includeTable)
}
includeDbList := []*datastreampb.MysqlDatabase{}
for _, db := range dbList {
// Create the include-database entry for this shard.
includeDb := &datastreampb.MysqlDatabase{
Database: db.DbName,
MysqlTables: mysqlTables,
}
includeDbList = append(includeDbList, includeDb)
}
//TODO: Clean up fmt.Printf logs and replace them with zap logger.
fmt.Printf("Include DB List for datastream: %+v\n", includeDbList)
mysqlSrcCfg := &datastreampb.MysqlSourceConfig{
IncludeObjects: &datastreampb.MysqlRdbms{MysqlDatabases: includeDbList},
MaxConcurrentBackfillTasks: 50,
}
return &datastreampb.SourceConfig_MysqlSourceConfig{MysqlSourceConfig: mysqlSrcCfg}
}

func getOracleSourceStreamConfig(dbName string, tableList []string) *datastreampb.SourceConfig_OracleSourceConfig {
oracleTables := []*datastreampb.OracleTable{}
for _, table := range tableList {
includeTable := &datastreampb.OracleTable{
Table: table,
}
oracleTables = append(oracleTables, includeTable)
}
oracledb := &datastreampb.OracleSchema{
Schema: dbName,
OracleTables: oracleTables,
}
oracleSrcCfg := &datastreampb.OracleSourceConfig{
IncludeObjects: &datastreampb.OracleRdbms{OracleSchemas: []*datastreampb.OracleSchema{oracledb}},
MaxConcurrentBackfillTasks: 50,
}
return &datastreampb.SourceConfig_OracleSourceConfig{OracleSourceConfig: oracleSrcCfg}
}

// getPostgreSQLSourceStreamConfig builds the PostgreSQL source configuration. The properties
// string must contain the replicationSlot and publication keys; common system schemas are
// excluded from replication.
func getPostgreSQLSourceStreamConfig(properties string) (*datastreampb.SourceConfig_PostgresqlSourceConfig, error) {
params, err := profiles.ParseMap(properties)
if err != nil {
return nil, fmt.Errorf("could not parse properties: %v", err)
}
var excludeObjects []*datastreampb.PostgresqlSchema
for _, s := range []string{"information_schema", "postgres", "pg_catalog", "pg_temp_1", "pg_toast", "pg_toast_temp_1"} {
excludeObjects = append(excludeObjects, &datastreampb.PostgresqlSchema{
Schema: s,
})
}
replicationSlot, replicationSlotExists := params["replicationSlot"]
publication, publicationExists := params["publication"]
if !replicationSlotExists || !publicationExists {
return nil, fmt.Errorf("replication slot or publication not specified")
}
postgresSrcCfg := &datastreampb.PostgresqlSourceConfig{
ExcludeObjects: &datastreampb.PostgresqlRdbms{PostgresqlSchemas: excludeObjects},
ReplicationSlot: replicationSlot,
Publication: publication,
MaxConcurrentBackfillTasks: 50,
}
return &datastreampb.SourceConfig_PostgresqlSourceConfig{PostgresqlSourceConfig: postgresSrcCfg}, nil
}

func getSourceStreamConfig(srcCfg *datastreampb.SourceConfig, sourceDatabaseType string, dbList []profiles.LogicalShard, datastreamCfg streaming.DatastreamCfg) error {
switch sourceDatabaseType {
case constants.MYSQL:
// For MySQL, Datastream supports sharded migrations and batching multiple databases on one
// physical machine into a single stream, so the full dbList is passed.
srcCfg.SourceStreamConfig = getMysqlSourceStreamConfig(dbList, datastreamCfg.TableList)
return nil
case constants.ORACLE:
// Oracle does not support sharded migrations or database batching, so dbList always contains exactly one element.
srcCfg.SourceStreamConfig = getOracleSourceStreamConfig(dbList[0].DbName, datastreamCfg.TableList)
return nil
case constants.POSTGRES:
// For PostgreSQL, tables must be configured at the schema level, which would require a List<Dbs> and a
// Map<Schema, List<Tables>> instead of a List<Dbs> and a List<Tables>. Because of this, the PostgreSQL
// Datastream is currently not configured at the individual table level.
sourceStreamConfig, err := getPostgreSQLSourceStreamConfig(datastreamCfg.Properties)
if err == nil {
srcCfg.SourceStreamConfig = sourceStreamConfig
}
return err
default:
return fmt.Errorf("only MySQL, Oracle and PostgreSQL are supported as source streams")
}
}
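A similar hedged sketch for the datastream accessor (hypothetical caller: the project, stream identifiers, prefix, and database name are placeholders, and the connection-profile fields of DatastreamCfg are assumed to be filled in from the migration profile before the call):

package main

import (
	"context"
	"log"

	datastreamacc "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/datastream"
	"github.com/GoogleCloudPlatform/spanner-migration-tool/streaming"
	"github.com/GoogleCloudPlatform/spanner-migration-tool/utils/constants"
)

func main() {
	ctx := context.Background()
	client := datastreamacc.GetInstance(ctx) // lazily-created singleton Datastream client

	// Placeholder configuration; connection profiles, table list, etc. are assumed
	// to be populated from the migration profile before this call.
	var dsCfg streaming.DatastreamCfg
	dsCfg.StreamId = "smt-demo-stream"
	dsCfg.StreamDisplayName = "smt-demo-stream"
	dsCfg.StreamLocation = "us-central1"

	if err := datastreamacc.CreateStream(ctx, client, "my-project", "/migration-data/", constants.MYSQL, "my_source_db", dsCfg); err != nil {
		log.Fatalf("create stream failed: %v", err)
	}
	log.Println("stream created and running")
}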