Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions api/workflow/worker_ce.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,12 @@ func StartWorkflow(cfg *config.Config) error {
if err != nil {
return err
}
temporalClient, err := client.Dial(client.Options{
client, err := temporal.NewClient(client.Options{
HostPort: cfg.WorkFLow.Endpoint,
})
}, "csghub-api")
if err != nil {
return fmt.Errorf("unable to create workflow client, error: %w", err)
}
client, err := temporal.NewClient(temporalClient)
if err != nil {
return err
}
Expand Down
35 changes: 21 additions & 14 deletions builder/temporal/temporal.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ package temporal

import (
"context"
"log/slog"
"fmt"

"go.opentelemetry.io/otel"
"go.temporal.io/sdk/client"
"go.temporal.io/sdk/log"
"go.temporal.io/sdk/contrib/opentelemetry"
"go.temporal.io/sdk/interceptor"
"go.temporal.io/sdk/worker"
)

Expand All @@ -28,25 +30,30 @@ type clientImpl struct {

var _client Client = &clientImpl{}

func NewClient(temporalClient client.Client) (*clientImpl, error) {
func NewClient(options client.Options, serviceName string) (*clientImpl, error) {
tracingInterceptor, err := opentelemetry.NewTracingInterceptor(opentelemetry.TracerOptions{
Tracer: otel.Tracer(serviceName),
})
if err != nil {
return nil, fmt.Errorf("temporal otel interceptor %w", err)
}
options.Interceptors = []interceptor.ClientInterceptor{tracingInterceptor}

t, err := client.Dial(options)
if err != nil {
return nil, err
}
c := _client.(*clientImpl)
c.Client = temporalClient
c.Client = t

return c, nil
}

func NewClientByHostPort(hostPort string) (*clientImpl, error) {
logger := log.NewStructuredLogger(slog.Default())
temporalClient, err := client.Dial(client.Options{
HostPort: hostPort,
Logger: logger,
})
if err != nil {
return nil, err
}
// used in test only
func Assign(temporalClient client.Client) {
c := _client.(*clientImpl)
c.Client = temporalClient
return c, nil

}

func (c *clientImpl) NewWorker(queue string, options worker.Options) worker.Registry {
Expand Down
5 changes: 2 additions & 3 deletions builder/temporal/temporal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,14 @@ func TestTemporalClient(t *testing.T) {
c := ts.GetDefaultClient()

tester := &Tester{client: temporal.GetClient()}
_, err := temporal.NewClient(c)
require.NoError(t, err)
temporal.Assign(c)

worker1 := tester.client.NewWorker("q1", worker.Options{})
worker1.RegisterWorkflow(tester.Count)
worker2 := tester.client.NewWorker("q2", worker.Options{})
worker2.RegisterWorkflow(tester.Add)

err = tester.client.Start()
err := tester.client.Start()
require.NoError(t, err)

r, err := tester.client.ExecuteWorkflow(context.TODO(), client.StartWorkflowOptions{
Expand Down
14 changes: 8 additions & 6 deletions cmd/csghub-server/cmd/dataviewer/launch.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"log/slog"

"github.com/spf13/cobra"
"go.temporal.io/sdk/client"
"opencsg.com/csghub-server/api/httpbase"
"opencsg.com/csghub-server/builder/instrumentation"
"opencsg.com/csghub-server/builder/store/database"
Expand Down Expand Up @@ -34,17 +35,18 @@ var launchCmd = &cobra.Command{
}
database.InitDB(dbConfig)

tc, err := temporal.NewClientByHostPort(cfg.WorkFLow.Endpoint)
if err != nil {
return fmt.Errorf("build workflow client, error: %w", err)
}

stopOtel, err := instrumentation.SetupOTelSDK(context.Background(), cfg, "dataviewer-api")
if err != nil {
panic(err)
}

r, err := router.NewDataViewerRouter(cfg, tc)
client, err := temporal.NewClient(client.Options{
HostPort: cfg.WorkFLow.Endpoint,
}, "dataset-viewer")
if err != nil {
return fmt.Errorf("unable to create workflow client, error: %w", err)
}
r, err := router.NewDataViewerRouter(cfg, client)
if err != nil {
return fmt.Errorf("failed to init dataviewer router: %w", err)
}
Expand Down
6 changes: 4 additions & 2 deletions dataviewer/component/callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,10 @@ func (c *callbackComponentImpl) TriggerDataviewUpdateWorkflow(ctx context.Contex
WorkflowExecutionTimeout: executeTimeOut,
WorkflowTaskTimeout: taskTimeout,
}
wfRun, err := c.workflowClient.ExecuteWorkflow(ctx, options, workflows.DataViewerUpdateWorkflow,
dvCom.WorkflowUpdateParams{Req: req, Config: c.cfg})
wfRun, err := c.workflowClient.ExecuteWorkflow(
context.Background(), options, workflows.DataViewerUpdateWorkflow,
dvCom.WorkflowUpdateParams{Req: req, Config: c.cfg},
)
if err != nil {
return nil, fmt.Errorf("fail to execute workflow, error: %w", err)
}
Expand Down
31 changes: 19 additions & 12 deletions dataviewer/workflows/activity.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"io"
"log/slog"
"math"
"net/http"
"net/url"
"os"
Expand Down Expand Up @@ -99,7 +100,7 @@ func (dva *dataViewerActivityImpl) GetCardFromReadme(ctx context.Context, req ty
Path: types.REPOCARD_FILENAME,
RepoType: req.RepoType,
}
f, err := dva.gitServer.GetRepoFileContents(context.Background(), fileReq)
f, err := dva.gitServer.GetRepoFileContents(ctx, fileReq)
if err != nil {
slog.Warn("get repo branch readme.md content error", slog.Any("fileReq", fileReq), slog.Any("err", err))
return &card, nil
Expand Down Expand Up @@ -127,15 +128,6 @@ func (dva *dataViewerActivityImpl) GetCardFromReadme(ctx context.Context, req ty
}

func (dva *dataViewerActivityImpl) ScanRepoFiles(ctx context.Context, scanParam dvCom.ScanRepoFileReq) (*dvCom.RepoFilesClass, error) {
repoReq := dvCom.RepoFilesReq{
Namespace: scanParam.Req.Namespace,
RepoName: scanParam.Req.Name,
RepoType: scanParam.Req.RepoType,
Ref: scanParam.Req.Branch,
Folder: "",
GSTree: dva.gitServer.GetRepoFileTree,
TotalLimitSize: scanParam.ConvertLimitSize,
}
fileClass := dvCom.RepoFilesClass{
AllFiles: make(map[string]*dvCom.RepoFile),
ParquetFiles: make(map[string]*dvCom.RepoFile),
Expand All @@ -144,9 +136,24 @@ func (dva *dataViewerActivityImpl) ScanRepoFiles(ctx context.Context, scanParam
TotalJsonSize: 0,
TotalCsvSize: 0,
}
err := GetFilePaths(repoReq, &fileClass)

resp, err := dva.gitServer.GetTree(ctx, types.GetTreeRequest{
Namespace: scanParam.Req.Namespace,
Name: scanParam.Req.Name,
RepoType: scanParam.Req.RepoType,
Ref: scanParam.Req.Branch,
Recursive: true,
Limit: math.MaxInt,
})
if err != nil {
return nil, fmt.Errorf("scan repo file error: %w", err)
return nil, err
}

for _, file := range resp.Files {
if file.Type == "dir" {
continue
}
appendFile(file, &fileClass, scanParam.ConvertLimitSize)
}
return &fileClass, nil
}
Expand Down
12 changes: 8 additions & 4 deletions dataviewer/workflows/activity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package workflows
import (
"context"
"encoding/base64"
"math"
"testing"

"github.com/stretchr/testify/mock"
Expand Down Expand Up @@ -91,15 +92,18 @@ func TestActivity_ScanRepoFiles(t *testing.T) {
RepoID: int64(1),
}

mockGitServer.EXPECT().GetRepoFileTree(mock.Anything, gitserver.GetRepoInfoByPathReq{
mockGitServer.EXPECT().GetTree(mock.Anything, types.GetTreeRequest{
Namespace: req.Namespace,
Name: req.Name,
Ref: req.Branch,
RepoType: req.RepoType,
Limit: math.MaxInt,
Recursive: true,
}).Return(
[]*types.File{
{Name: "foobar.parquet", Path: "foo/foobar.parquet"},
}, nil,
&types.GetRepoFileTreeResp{
Files: []*types.File{
{Name: "foobar.parquet", Path: "foo/foobar.parquet"},
}}, nil,
)

dvActivity, err := NewTestDataViewerActivity(config, mockGitServer, s3Client, dvstore)
Expand Down
22 changes: 8 additions & 14 deletions dataviewer/workflows/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ import (
"io"
"log/slog"
"os"
"regexp"
"strings"
"time"

"github.com/bmatcuk/doublestar/v4"
"go.temporal.io/sdk/activity"
"opencsg.com/csghub-server/common/types"
dvCom "opencsg.com/csghub-server/dataviewer/common"
Expand Down Expand Up @@ -51,29 +51,23 @@ func GetPatternFileList(path interface{}) []string {
return files
}

func ConvertRealFiles(splitFiles []string, sortKeys []string, targetFiles map[string]*dvCom.RepoFile, subsetName, splitName string) []dvCom.FileObject {
func ConvertRealFiles(splitFiles []string, filePaths []string, targetFiles map[string]*dvCom.RepoFile, subsetName, splitName string) []dvCom.FileObject {
var phyFiles []dvCom.FileObject
for _, filePattern := range splitFiles {
if !strings.Contains(filePattern, dvCom.WILDCARD) {
if !strings.Contains(filePattern, dvCom.WILDCARD) || !doublestar.ValidatePathPattern(filePattern) {
file, exists := targetFiles[filePattern]
if exists {
phyFiles = append(phyFiles, TransferFileObject(file, subsetName, splitName))
}
continue
}

fileReg, err := regexp.Compile(filePattern)
if err != nil {
slog.Warn("invalid regexp format of split file", slog.Any("filePattern", filePattern), slog.Any("err", err))
file, exists := targetFiles[filePattern]
if exists {
phyFiles = append(phyFiles, TransferFileObject(file, subsetName, splitName))
for _, path := range filePaths {
match, err := doublestar.PathMatch(filePattern, path)
if err != nil {
slog.Error("file pattern match", "error", err)
}
continue
}
for _, path := range sortKeys {
// repo file match like: test/test-*
if fileReg.MatchString(path) {
if match {
file, exists := targetFiles[path]
if exists {
phyFiles = append(phyFiles, TransferFileObject(file, subsetName, splitName))
Expand Down
64 changes: 43 additions & 21 deletions dataviewer/workflows/utils_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package workflows

import (
"fmt"
"io"
"os"
"strings"
Expand All @@ -26,29 +27,50 @@ func TestUtils_GetPatternFileList(t *testing.T) {
}

func TestUtils_ConvertRealFiles(t *testing.T) {
splitFiles := []string{"a/1.parquet", "b/2.parquet"}
sortKeys := []string{"a", "b"}

targetFiles := map[string]*dvCom.RepoFile{
"a/1.parquet": {
File: &types.File{
Path: "a/1.parquet",
},
},
"b/2.parquet": {
File: &types.File{
Path: "b/2.parquet",
},
},
"c/3.parquet": {
File: &types.File{
Path: "c/3.parquet",
},
},
exists := map[string]*dvCom.RepoFile{}
paths := []string{
"foo/a.csv",
"foo/b.csv",
"foo/a.json",
"bar/c.csv",
"bar/d.csv",
"bar/a.json",
"foo/v1/e.csv",
"foo/v2/f.csv",
"foo/v1/t1/g.csv",
}
for _, path := range paths {
exists[path] = &dvCom.RepoFile{File: &types.File{Path: path}}
}
// not exists files
paths = append(paths, "foo/zz.csv")
paths = append(paths, "bar/qq.csv")

cases := []struct {
split string
expected []string
}{
{split: "foobar/a.csv", expected: []string{}},
{split: "foo/a.csv", expected: []string{"foo/a.csv"}},
{split: "foo/*.csv", expected: []string{"foo/a.csv", "foo/b.csv"}},
{split: "bar/*.csv", expected: []string{"bar/c.csv", "bar/d.csv"}},
{split: "foo/**/*.csv", expected: []string{
"foo/a.csv", "foo/b.csv", "foo/v1/e.csv", "foo/v2/f.csv",
"foo/v1/t1/g.csv",
}},
{split: "bar/**/*.csv", expected: []string{"bar/c.csv", "bar/d.csv"}},
}

res := ConvertRealFiles(splitFiles, sortKeys, targetFiles, "default", "train")
require.Equal(t, 2, len(res))
for _, c := range cases {
t.Run(fmt.Sprintf("%+v", c), func(t *testing.T) {
match := ConvertRealFiles([]string{c.split}, paths, exists, "default", "test")
paths := []string{}
for _, f := range match {
paths = append(paths, f.RepoFile)
}
require.Equal(t, c.expected, paths)
})
}
}

func TestUtils_GetCardDataMD5(t *testing.T) {
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ require (
github.com/alibabacloud-go/green-20220302 v1.2.0
github.com/alibabacloud-go/tea v1.2.1
github.com/aliyun/alibaba-cloud-sdk-go v1.62.648
github.com/bmatcuk/doublestar/v4 v4.8.1
github.com/casdoor/casdoor-go-sdk v0.41.0
github.com/chenyahui/gin-cache v1.9.0
github.com/d5/tengo/v2 v2.17.0
Expand Down Expand Up @@ -61,6 +62,7 @@ require (
go.opentelemetry.io/otel/trace v1.33.0
go.temporal.io/api v1.43.0
go.temporal.io/sdk v1.31.0
go.temporal.io/sdk/contrib/opentelemetry v0.6.0
go.temporal.io/server v1.26.2
google.golang.org/grpc v1.68.1
gopkg.in/yaml.v2 v2.4.0
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM
github.com/blang/semver/v4 v4.0.0/go.mod h1:IbckMUScFkM3pff0VJDNKRiT6TG/YpiHIM2yvyW5YoQ=
github.com/blendle/zapdriver v1.3.1 h1:C3dydBOWYRiOk+B8X9IVZ5IOe+7cl+tGOexN4QqHfpE=
github.com/blendle/zapdriver v1.3.1/go.mod h1:mdXfREi6u5MArG4j9fewC+FGnXaBR+T4Ox4J2u4eHCc=
github.com/bmatcuk/doublestar/v4 v4.8.1 h1:54Bopc5c2cAvhLRAzqOGCYHYyhcDHsFF4wWIR5wKP38=
github.com/bmatcuk/doublestar/v4 v4.8.1/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
github.com/bmizerany/perks v0.0.0-20141205001514-d9a9656a3a4b h1:AP/Y7sqYicnjGDfD5VcY4CIfh1hRXBUavxrvELjTiOE=
Expand Down Expand Up @@ -828,6 +830,8 @@ go.temporal.io/api v1.43.0 h1:lBhq+u5qFJqGMXwWsmg/i8qn1UA/3LCwVc88l2xUMHg=
go.temporal.io/api v1.43.0/go.mod h1:1WwYUMo6lao8yl0371xWUm13paHExN5ATYT/B7QtFis=
go.temporal.io/sdk v1.31.0 h1:CLYiP0R5Sdj0gq8LyYKDDz4ccGOdJPR8wNGJU0JGwj8=
go.temporal.io/sdk v1.31.0/go.mod h1:8U8H7rF9u4Hyb4Ry9yiEls5716DHPNvVITPNkgWUwE8=
go.temporal.io/sdk/contrib/opentelemetry v0.6.0 h1:rNBArDj5iTUkcMwKocUShoAW59o6HdS7Nq4CTp4ldj8=
go.temporal.io/sdk/contrib/opentelemetry v0.6.0/go.mod h1:Lem8VrE2ks8P+FYcRM3UphPoBr+tfM3v/Kaf0qStzSg=
go.temporal.io/server v1.26.2 h1:vDW11lxslYPlGDbQklWi/tqbkVZ2ExtRO1jNjvZmUUI=
go.temporal.io/server v1.26.2/go.mod h1:tgY+4z/PuIdqs6ouV1bT90RWSWfEioWkzmrNrLYLUrk=
go.temporal.io/version v0.3.0 h1:dMrei9l9NyHt8nG6EB8vAwDLLTwx2SvRyucCSumAiig=
Expand Down
Loading