Skip to content

Commit 9ebc990

Browse files
authored
feat(go/bigquery): python models(bigframes) (#97)
1 parent ec96225 commit 9ebc990

File tree

5 files changed

+482
-0
lines changed

5 files changed

+482
-0
lines changed

go/adbc/driver/bigquery/connection.go

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,20 @@ import (
2222
"context"
2323
"crypto/tls"
2424
"encoding/json"
25+
"errors"
2526
"fmt"
2627
"io"
2728
"log"
2829
"net/http"
2930
"net/url"
31+
"os"
3032
"regexp"
3133
"strconv"
3234
"strings"
3335
"time"
3436

37+
aiplatform "cloud.google.com/go/aiplatform/apiv1"
38+
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
3539
"cloud.google.com/go/bigquery"
3640
dataproc "cloud.google.com/go/dataproc/v2/apiv1"
3741
"cloud.google.com/go/storage"
@@ -43,6 +47,7 @@ import (
4347
"google.golang.org/api/impersonate"
4448
"google.golang.org/api/iterator"
4549
"google.golang.org/api/option"
50+
"google.golang.org/api/transport"
4651
)
4752

4853
type connectionImpl struct {
@@ -760,6 +765,123 @@ func (c *connectionImpl) newGCSClient(ctx context.Context) (*storage.Client, err
760765
return client, nil
761766
}
762767

768+
func (c *connectionImpl) newNotebookClient(ctx context.Context, computeRegion string) (*aiplatform.NotebookClient, error) {
769+
authOptions, err := c.authOptions(ctx)
770+
if err != nil {
771+
return nil, err
772+
}
773+
774+
authOptions = append(authOptions, option.WithEndpoint(fmt.Sprintf("%s-aiplatform.googleapis.com:443", computeRegion)))
775+
776+
client, err := aiplatform.NewNotebookClient(ctx, authOptions...)
777+
if err != nil {
778+
return nil, err
779+
}
780+
781+
return client, nil
782+
}
783+
784+
func (c *connectionImpl) addExecutionIdentitiyDetails(ctx context.Context, job *aiplatformpb.NotebookExecutionJob) (*aiplatformpb.NotebookExecutionJob, error) {
785+
switch c.authType {
786+
case OptionValueAuthTypeJSONCredentialFile:
787+
data, err := os.ReadFile(c.credentials)
788+
if err != nil {
789+
panic(fmt.Errorf("failed to read JSON file: %v", err))
790+
}
791+
792+
var sa struct {
793+
ClientEmail string `json:"client_email"`
794+
}
795+
if err := json.Unmarshal(data, &sa); err != nil {
796+
panic(fmt.Errorf("failed to parse JSON: %v", err))
797+
}
798+
job.ExecutionIdentity = &aiplatformpb.NotebookExecutionJob_ServiceAccount{
799+
ServiceAccount: sa.ClientEmail,
800+
}
801+
return job, nil
802+
case OptionValueAuthTypeJSONCredentialString:
803+
data := []byte(c.credentials)
804+
var sa struct {
805+
ClientEmail string `json:"client_email"`
806+
}
807+
if err := json.Unmarshal(data, &sa); err != nil {
808+
panic(fmt.Errorf("failed to parse JSON string: %v", err))
809+
}
810+
811+
job.ExecutionIdentity = &aiplatformpb.NotebookExecutionJob_ServiceAccount{
812+
ServiceAccount: sa.ClientEmail,
813+
}
814+
return job, nil
815+
case OptionValueAuthTypeDefault,
816+
OptionValueAuthTypeUserAuthentication,
817+
OptionValueAuthTypeTemporaryAccessToken:
818+
if c.impersonateTargetPrincipal != "" {
819+
job.ExecutionIdentity = &aiplatformpb.NotebookExecutionJob_ServiceAccount{
820+
ServiceAccount: c.impersonateTargetPrincipal,
821+
}
822+
} else {
823+
authOptions, err := c.authOptions(ctx)
824+
if err != nil {
825+
return nil, err
826+
}
827+
ts, _, err := transport.NewHTTPClient(ctx, append(authOptions, option.WithScopes("https://www.googleapis.com/auth/userinfo.email"))...)
828+
if err != nil {
829+
panic(err)
830+
}
831+
tokenSource := oauth2.StaticTokenSource(&oauth2.Token{}) // placeholder
832+
if t, ok := ts.Transport.(*oauth2.Transport); ok {
833+
tokenSource = t.Source
834+
}
835+
token, err := tokenSource.Token()
836+
if err != nil {
837+
panic(err)
838+
}
839+
url := "https://www.googleapis.com/oauth2/v2/userinfo"
840+
req, err := http.NewRequest("GET", url, nil)
841+
if err != nil {
842+
panic(err)
843+
}
844+
req.Header.Add("Authorization", "Bearer "+token.AccessToken)
845+
client := &http.Client{}
846+
resp, err := client.Do(req)
847+
if err != nil {
848+
panic(err)
849+
}
850+
defer resp.Body.Close()
851+
body, err := io.ReadAll(resp.Body)
852+
if err != nil {
853+
panic(err)
854+
}
855+
if resp.StatusCode != http.StatusOK {
856+
panic(fmt.Errorf("failed to retrieve user info. Status: %d, Body: %s", resp.StatusCode, string(body)))
857+
}
858+
var data map[string]interface{}
859+
if err := json.Unmarshal(body, &data); err != nil {
860+
panic(err)
861+
}
862+
email, ok := data["email"].(string)
863+
if !ok || email == "" {
864+
panic(errors.New("authorization request to get user failed to return an email"))
865+
}
866+
if strings.HasSuffix(email, "iam.gserviceaccount.com") {
867+
job.ExecutionIdentity = &aiplatformpb.NotebookExecutionJob_ServiceAccount{
868+
ServiceAccount: email,
869+
}
870+
} else {
871+
job.ExecutionIdentity = &aiplatformpb.NotebookExecutionJob_ExecutionUser{
872+
ExecutionUser: email,
873+
}
874+
}
875+
}
876+
return job, nil
877+
default:
878+
return nil, adbc.Error{
879+
Code: adbc.StatusInvalidArgument,
880+
Msg: "Unsupported credential method in BigFrames",
881+
}
882+
}
883+
}
884+
763885
func (c *connectionImpl) hasImpersonationOptions() bool {
764886
return c.impersonateTargetPrincipal != "" ||
765887
len(c.impersonateDelegates) > 0

go/adbc/driver/bigquery/driver.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,15 @@ const (
105105
OptionStringWriteGCSObjectName = "adbc.bigquery.write_gcs.object_name"
106106
OptionStringWriteGCSContent = "adbc.bigquery.write_gcs.content"
107107

108+
OptionStringNotebookExecuteJobGscPath = "adbc.bigquery.notebook_execute_job.gsc_path"
109+
OptionStringNotebookExecuteJobModelFileName = "adbc.bigquery.notebook_execute_job.model_file_name"
110+
OptionStringNotebookExecuteJobModelName = "adbc.bigquery.notebook_execute_job.model_name"
111+
OptionStringNotebookExecuteJobGscBucket = "adbc.bigquery.notebook_execute_job.gsc_bucket"
112+
OptionStringNotebookExecuteJobTemplateId = "adbc.bigquery.notebook_execute_job.template_id"
113+
OptionStringNotebookExecuteJobParent = "adbc.bigquery.notebook_execute_job.parent"
114+
OptionStringNotebookExecuteJobProject = "adbc.bigquery.notebook_execute_job.project"
115+
OptionStringNotebookExecuteJobRegion = "adbc.bigquery.notebook_execute_job.region"
116+
108117
OptionJsonUpdateTableColumnsDescription = "adbc.bigquery.table.update_columns_description"
109118
OptionJsonAuthorizeViewToDatasets = "adbc.bigquery.dataset.authorize_view_to_datasets"
110119

0 commit comments

Comments
 (0)