Skip to content

Commit 5cf5cf0

Browse files
yiling.ji authored and
Yiling-J committed
Merge branch 'refactor/dataset_viewer_cat' into 'main'
Dataset viewer workflow refactor and bug fix. See merge request product/starhub/starhub-server!873
1 parent f08c508 commit 5cf5cf0

File tree

11 files changed

+121
-79
lines changed

11 files changed

+121
-79
lines changed

api/workflow/worker_ce.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,12 @@ func StartWorkflow(cfg *config.Config) error {
3838
if err != nil {
3939
return err
4040
}
41-
temporalClient, err := client.Dial(client.Options{
41+
client, err := temporal.NewClient(client.Options{
4242
HostPort: cfg.WorkFLow.Endpoint,
43-
})
43+
}, "csghub-api")
4444
if err != nil {
4545
return fmt.Errorf("unable to create workflow client, error: %w", err)
4646
}
47-
client, err := temporal.NewClient(temporalClient)
4847
if err != nil {
4948
return err
5049
}

builder/temporal/temporal.go

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@ package temporal
22

33
import (
44
"context"
5-
"log/slog"
5+
"fmt"
66

7+
"go.opentelemetry.io/otel"
78
"go.temporal.io/sdk/client"
8-
"go.temporal.io/sdk/log"
9+
"go.temporal.io/sdk/contrib/opentelemetry"
10+
"go.temporal.io/sdk/interceptor"
911
"go.temporal.io/sdk/worker"
1012
)
1113

@@ -28,25 +30,30 @@ type clientImpl struct {
2830

2931
var _client Client = &clientImpl{}
3032

31-
func NewClient(temporalClient client.Client) (*clientImpl, error) {
33+
func NewClient(options client.Options, serviceName string) (*clientImpl, error) {
34+
tracingInterceptor, err := opentelemetry.NewTracingInterceptor(opentelemetry.TracerOptions{
35+
Tracer: otel.Tracer(serviceName),
36+
})
37+
if err != nil {
38+
return nil, fmt.Errorf("temporal otel interceptor %w", err)
39+
}
40+
options.Interceptors = []interceptor.ClientInterceptor{tracingInterceptor}
41+
42+
t, err := client.Dial(options)
43+
if err != nil {
44+
return nil, err
45+
}
3246
c := _client.(*clientImpl)
33-
c.Client = temporalClient
47+
c.Client = t
3448

3549
return c, nil
3650
}
3751

38-
func NewClientByHostPort(hostPort string) (*clientImpl, error) {
39-
logger := log.NewStructuredLogger(slog.Default())
40-
temporalClient, err := client.Dial(client.Options{
41-
HostPort: hostPort,
42-
Logger: logger,
43-
})
44-
if err != nil {
45-
return nil, err
46-
}
52+
// used in test only
53+
func Assign(temporalClient client.Client) {
4754
c := _client.(*clientImpl)
4855
c.Client = temporalClient
49-
return c, nil
56+
5057
}
5158

5259
func (c *clientImpl) NewWorker(queue string, options worker.Options) worker.Registry {

builder/temporal/temporal_test.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,14 @@ func TestTemporalClient(t *testing.T) {
3434
c := ts.GetDefaultClient()
3535

3636
tester := &Tester{client: temporal.GetClient()}
37-
_, err := temporal.NewClient(c)
38-
require.NoError(t, err)
37+
temporal.Assign(c)
3938

4039
worker1 := tester.client.NewWorker("q1", worker.Options{})
4140
worker1.RegisterWorkflow(tester.Count)
4241
worker2 := tester.client.NewWorker("q2", worker.Options{})
4342
worker2.RegisterWorkflow(tester.Add)
4443

45-
err = tester.client.Start()
44+
err := tester.client.Start()
4645
require.NoError(t, err)
4746

4847
r, err := tester.client.ExecuteWorkflow(context.TODO(), client.StartWorkflowOptions{

cmd/csghub-server/cmd/dataviewer/launch.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"log/slog"
77

88
"github.com/spf13/cobra"
9+
"go.temporal.io/sdk/client"
910
"opencsg.com/csghub-server/api/httpbase"
1011
"opencsg.com/csghub-server/builder/instrumentation"
1112
"opencsg.com/csghub-server/builder/store/database"
@@ -34,17 +35,18 @@ var launchCmd = &cobra.Command{
3435
}
3536
database.InitDB(dbConfig)
3637

37-
tc, err := temporal.NewClientByHostPort(cfg.WorkFLow.Endpoint)
38-
if err != nil {
39-
return fmt.Errorf("build workflow client, error: %w", err)
40-
}
41-
4238
stopOtel, err := instrumentation.SetupOTelSDK(context.Background(), cfg, "dataviewer-api")
4339
if err != nil {
4440
panic(err)
4541
}
4642

47-
r, err := router.NewDataViewerRouter(cfg, tc)
43+
client, err := temporal.NewClient(client.Options{
44+
HostPort: cfg.WorkFLow.Endpoint,
45+
}, "dataset-viewer")
46+
if err != nil {
47+
return fmt.Errorf("unable to create workflow client, error: %w", err)
48+
}
49+
r, err := router.NewDataViewerRouter(cfg, client)
4850
if err != nil {
4951
return fmt.Errorf("failed to init dataviewer router: %w", err)
5052
}

dataviewer/component/callback.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,10 @@ func (c *callbackComponentImpl) TriggerDataviewUpdateWorkflow(ctx context.Contex
100100
WorkflowExecutionTimeout: executeTimeOut,
101101
WorkflowTaskTimeout: taskTimeout,
102102
}
103-
wfRun, err := c.workflowClient.ExecuteWorkflow(ctx, options, workflows.DataViewerUpdateWorkflow,
104-
dvCom.WorkflowUpdateParams{Req: req, Config: c.cfg})
103+
wfRun, err := c.workflowClient.ExecuteWorkflow(
104+
context.Background(), options, workflows.DataViewerUpdateWorkflow,
105+
dvCom.WorkflowUpdateParams{Req: req, Config: c.cfg},
106+
)
105107
if err != nil {
106108
return nil, fmt.Errorf("fail to execute workflow, error: %w", err)
107109
}

dataviewer/workflows/activity.go

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"fmt"
88
"io"
99
"log/slog"
10+
"math"
1011
"net/http"
1112
"net/url"
1213
"os"
@@ -99,7 +100,7 @@ func (dva *dataViewerActivityImpl) GetCardFromReadme(ctx context.Context, req ty
99100
Path: types.REPOCARD_FILENAME,
100101
RepoType: req.RepoType,
101102
}
102-
f, err := dva.gitServer.GetRepoFileContents(context.Background(), fileReq)
103+
f, err := dva.gitServer.GetRepoFileContents(ctx, fileReq)
103104
if err != nil {
104105
slog.Warn("get repo branch readme.md content error", slog.Any("fileReq", fileReq), slog.Any("err", err))
105106
return &card, nil
@@ -127,15 +128,6 @@ func (dva *dataViewerActivityImpl) GetCardFromReadme(ctx context.Context, req ty
127128
}
128129

129130
func (dva *dataViewerActivityImpl) ScanRepoFiles(ctx context.Context, scanParam dvCom.ScanRepoFileReq) (*dvCom.RepoFilesClass, error) {
130-
repoReq := dvCom.RepoFilesReq{
131-
Namespace: scanParam.Req.Namespace,
132-
RepoName: scanParam.Req.Name,
133-
RepoType: scanParam.Req.RepoType,
134-
Ref: scanParam.Req.Branch,
135-
Folder: "",
136-
GSTree: dva.gitServer.GetRepoFileTree,
137-
TotalLimitSize: scanParam.ConvertLimitSize,
138-
}
139131
fileClass := dvCom.RepoFilesClass{
140132
AllFiles: make(map[string]*dvCom.RepoFile),
141133
ParquetFiles: make(map[string]*dvCom.RepoFile),
@@ -144,9 +136,24 @@ func (dva *dataViewerActivityImpl) ScanRepoFiles(ctx context.Context, scanParam
144136
TotalJsonSize: 0,
145137
TotalCsvSize: 0,
146138
}
147-
err := GetFilePaths(repoReq, &fileClass)
139+
140+
resp, err := dva.gitServer.GetTree(ctx, types.GetTreeRequest{
141+
Namespace: scanParam.Req.Namespace,
142+
Name: scanParam.Req.Name,
143+
RepoType: scanParam.Req.RepoType,
144+
Ref: scanParam.Req.Branch,
145+
Recursive: true,
146+
Limit: math.MaxInt,
147+
})
148148
if err != nil {
149-
return nil, fmt.Errorf("scan repo file error: %w", err)
149+
return nil, err
150+
}
151+
152+
for _, file := range resp.Files {
153+
if file.Type == "dir" {
154+
continue
155+
}
156+
appendFile(file, &fileClass, scanParam.ConvertLimitSize)
150157
}
151158
return &fileClass, nil
152159
}

dataviewer/workflows/activity_test.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package workflows
33
import (
44
"context"
55
"encoding/base64"
6+
"math"
67
"testing"
78

89
"github.com/stretchr/testify/mock"
@@ -91,15 +92,18 @@ func TestActivity_ScanRepoFiles(t *testing.T) {
9192
RepoID: int64(1),
9293
}
9394

94-
mockGitServer.EXPECT().GetRepoFileTree(mock.Anything, gitserver.GetRepoInfoByPathReq{
95+
mockGitServer.EXPECT().GetTree(mock.Anything, types.GetTreeRequest{
9596
Namespace: req.Namespace,
9697
Name: req.Name,
9798
Ref: req.Branch,
9899
RepoType: req.RepoType,
100+
Limit: math.MaxInt,
101+
Recursive: true,
99102
}).Return(
100-
[]*types.File{
101-
{Name: "foobar.parquet", Path: "foo/foobar.parquet"},
102-
}, nil,
103+
&types.GetRepoFileTreeResp{
104+
Files: []*types.File{
105+
{Name: "foobar.parquet", Path: "foo/foobar.parquet"},
106+
}}, nil,
103107
)
104108

105109
dvActivity, err := NewTestDataViewerActivity(config, mockGitServer, s3Client, dvstore)

dataviewer/workflows/utils.go

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ import (
1010
"io"
1111
"log/slog"
1212
"os"
13-
"regexp"
1413
"strings"
1514
"time"
1615

16+
"github.com/bmatcuk/doublestar/v4"
1717
"go.temporal.io/sdk/activity"
1818
"opencsg.com/csghub-server/common/types"
1919
dvCom "opencsg.com/csghub-server/dataviewer/common"
@@ -51,29 +51,23 @@ func GetPatternFileList(path interface{}) []string {
5151
return files
5252
}
5353

54-
func ConvertRealFiles(splitFiles []string, sortKeys []string, targetFiles map[string]*dvCom.RepoFile, subsetName, splitName string) []dvCom.FileObject {
54+
func ConvertRealFiles(splitFiles []string, filePaths []string, targetFiles map[string]*dvCom.RepoFile, subsetName, splitName string) []dvCom.FileObject {
5555
var phyFiles []dvCom.FileObject
5656
for _, filePattern := range splitFiles {
57-
if !strings.Contains(filePattern, dvCom.WILDCARD) {
57+
if !strings.Contains(filePattern, dvCom.WILDCARD) || !doublestar.ValidatePathPattern(filePattern) {
5858
file, exists := targetFiles[filePattern]
5959
if exists {
6060
phyFiles = append(phyFiles, TransferFileObject(file, subsetName, splitName))
6161
}
6262
continue
6363
}
6464

65-
fileReg, err := regexp.Compile(filePattern)
66-
if err != nil {
67-
slog.Warn("invalid regexp format of split file", slog.Any("filePattern", filePattern), slog.Any("err", err))
68-
file, exists := targetFiles[filePattern]
69-
if exists {
70-
phyFiles = append(phyFiles, TransferFileObject(file, subsetName, splitName))
65+
for _, path := range filePaths {
66+
match, err := doublestar.PathMatch(filePattern, path)
67+
if err != nil {
68+
slog.Error("file pattern match", "error", err)
7169
}
72-
continue
73-
}
74-
for _, path := range sortKeys {
75-
// repo file match like: test/test-*
76-
if fileReg.MatchString(path) {
70+
if match {
7771
file, exists := targetFiles[path]
7872
if exists {
7973
phyFiles = append(phyFiles, TransferFileObject(file, subsetName, splitName))

dataviewer/workflows/utils_test.go

Lines changed: 43 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package workflows
22

33
import (
4+
"fmt"
45
"io"
56
"os"
67
"strings"
@@ -26,29 +27,50 @@ func TestUtils_GetPatternFileList(t *testing.T) {
2627
}
2728

2829
func TestUtils_ConvertRealFiles(t *testing.T) {
29-
splitFiles := []string{"a/1.parquet", "b/2.parquet"}
30-
sortKeys := []string{"a", "b"}
31-
32-
targetFiles := map[string]*dvCom.RepoFile{
33-
"a/1.parquet": {
34-
File: &types.File{
35-
Path: "a/1.parquet",
36-
},
37-
},
38-
"b/2.parquet": {
39-
File: &types.File{
40-
Path: "b/2.parquet",
41-
},
42-
},
43-
"c/3.parquet": {
44-
File: &types.File{
45-
Path: "c/3.parquet",
46-
},
47-
},
30+
exists := map[string]*dvCom.RepoFile{}
31+
paths := []string{
32+
"foo/a.csv",
33+
"foo/b.csv",
34+
"foo/a.json",
35+
"bar/c.csv",
36+
"bar/d.csv",
37+
"bar/a.json",
38+
"foo/v1/e.csv",
39+
"foo/v2/f.csv",
40+
"foo/v1/t1/g.csv",
41+
}
42+
for _, path := range paths {
43+
exists[path] = &dvCom.RepoFile{File: &types.File{Path: path}}
44+
}
45+
// not exists files
46+
paths = append(paths, "foo/zz.csv")
47+
paths = append(paths, "bar/qq.csv")
48+
49+
cases := []struct {
50+
split string
51+
expected []string
52+
}{
53+
{split: "foobar/a.csv", expected: []string{}},
54+
{split: "foo/a.csv", expected: []string{"foo/a.csv"}},
55+
{split: "foo/*.csv", expected: []string{"foo/a.csv", "foo/b.csv"}},
56+
{split: "bar/*.csv", expected: []string{"bar/c.csv", "bar/d.csv"}},
57+
{split: "foo/**/*.csv", expected: []string{
58+
"foo/a.csv", "foo/b.csv", "foo/v1/e.csv", "foo/v2/f.csv",
59+
"foo/v1/t1/g.csv",
60+
}},
61+
{split: "bar/**/*.csv", expected: []string{"bar/c.csv", "bar/d.csv"}},
4862
}
4963

50-
res := ConvertRealFiles(splitFiles, sortKeys, targetFiles, "default", "train")
51-
require.Equal(t, 2, len(res))
64+
for _, c := range cases {
65+
t.Run(fmt.Sprintf("%+v", c), func(t *testing.T) {
66+
match := ConvertRealFiles([]string{c.split}, paths, exists, "default", "test")
67+
paths := []string{}
68+
for _, f := range match {
69+
paths = append(paths, f.RepoFile)
70+
}
71+
require.Equal(t, c.expected, paths)
72+
})
73+
}
5274
}
5375

5476
func TestUtils_GetCardDataMD5(t *testing.T) {

go.mod

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ require (
99
github.com/alibabacloud-go/green-20220302 v1.2.0
1010
github.com/alibabacloud-go/tea v1.2.1
1111
github.com/aliyun/alibaba-cloud-sdk-go v1.62.648
12+
github.com/bmatcuk/doublestar/v4 v4.8.1
1213
github.com/casdoor/casdoor-go-sdk v0.41.0
1314
github.com/chenyahui/gin-cache v1.9.0
1415
github.com/d5/tengo/v2 v2.17.0
@@ -61,6 +62,7 @@ require (
6162
go.opentelemetry.io/otel/trace v1.33.0
6263
go.temporal.io/api v1.43.0
6364
go.temporal.io/sdk v1.31.0
65+
go.temporal.io/sdk/contrib/opentelemetry v0.6.0
6466
go.temporal.io/server v1.26.2
6567
google.golang.org/grpc v1.68.1
6668
gopkg.in/yaml.v2 v2.4.0

0 commit comments

Comments
 (0)