@@ -22,16 +22,20 @@ import (
2222 "context"
2323 "crypto/tls"
2424 "encoding/json"
25+ "errors"
2526 "fmt"
2627 "io"
2728 "log"
2829 "net/http"
2930 "net/url"
31+ "os"
3032 "regexp"
3133 "strconv"
3234 "strings"
3335 "time"
3436
37+ aiplatform "cloud.google.com/go/aiplatform/apiv1"
38+ "cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
3539 "cloud.google.com/go/bigquery"
3640 dataproc "cloud.google.com/go/dataproc/v2/apiv1"
3741 "cloud.google.com/go/storage"
@@ -43,6 +47,7 @@ import (
4347 "google.golang.org/api/impersonate"
4448 "google.golang.org/api/iterator"
4549 "google.golang.org/api/option"
50+ "google.golang.org/api/transport"
4651)
4752
4853type connectionImpl struct {
@@ -760,6 +765,123 @@ func (c *connectionImpl) newGCSClient(ctx context.Context) (*storage.Client, err
760765 return client , nil
761766}
762767
768+ func (c * connectionImpl ) newNotebookClient (ctx context.Context , computeRegion string ) (* aiplatform.NotebookClient , error ) {
769+ authOptions , err := c .authOptions (ctx )
770+ if err != nil {
771+ return nil , err
772+ }
773+
774+ authOptions = append (authOptions , option .WithEndpoint (fmt .Sprintf ("%s-aiplatform.googleapis.com:443" , computeRegion )))
775+
776+ client , err := aiplatform .NewNotebookClient (ctx , authOptions ... )
777+ if err != nil {
778+ return nil , err
779+ }
780+
781+ return client , nil
782+ }
783+
784+ func (c * connectionImpl ) addExecutionIdentitiyDetails (ctx context.Context , job * aiplatformpb.NotebookExecutionJob ) (* aiplatformpb.NotebookExecutionJob , error ) {
785+ switch c .authType {
786+ case OptionValueAuthTypeJSONCredentialFile :
787+ data , err := os .ReadFile (c .credentials )
788+ if err != nil {
789+ panic (fmt .Errorf ("failed to read JSON file: %v" , err ))
790+ }
791+
792+ var sa struct {
793+ ClientEmail string `json:"client_email"`
794+ }
795+ if err := json .Unmarshal (data , & sa ); err != nil {
796+ panic (fmt .Errorf ("failed to parse JSON: %v" , err ))
797+ }
798+ job .ExecutionIdentity = & aiplatformpb.NotebookExecutionJob_ServiceAccount {
799+ ServiceAccount : sa .ClientEmail ,
800+ }
801+ return job , nil
802+ case OptionValueAuthTypeJSONCredentialString :
803+ data := []byte (c .credentials )
804+ var sa struct {
805+ ClientEmail string `json:"client_email"`
806+ }
807+ if err := json .Unmarshal (data , & sa ); err != nil {
808+ panic (fmt .Errorf ("failed to parse JSON string: %v" , err ))
809+ }
810+
811+ job .ExecutionIdentity = & aiplatformpb.NotebookExecutionJob_ServiceAccount {
812+ ServiceAccount : sa .ClientEmail ,
813+ }
814+ return job , nil
815+ case OptionValueAuthTypeDefault ,
816+ OptionValueAuthTypeUserAuthentication ,
817+ OptionValueAuthTypeTemporaryAccessToken :
818+ if c .impersonateTargetPrincipal != "" {
819+ job .ExecutionIdentity = & aiplatformpb.NotebookExecutionJob_ServiceAccount {
820+ ServiceAccount : c .impersonateTargetPrincipal ,
821+ }
822+ } else {
823+ authOptions , err := c .authOptions (ctx )
824+ if err != nil {
825+ return nil , err
826+ }
827+ ts , _ , err := transport .NewHTTPClient (ctx , append (authOptions , option .WithScopes ("https://www.googleapis.com/auth/userinfo.email" ))... )
828+ if err != nil {
829+ panic (err )
830+ }
831+ tokenSource := oauth2 .StaticTokenSource (& oauth2.Token {}) // placeholder
832+ if t , ok := ts .Transport .(* oauth2.Transport ); ok {
833+ tokenSource = t .Source
834+ }
835+ token , err := tokenSource .Token ()
836+ if err != nil {
837+ panic (err )
838+ }
839+ url := "https://www.googleapis.com/oauth2/v2/userinfo"
840+ req , err := http .NewRequest ("GET" , url , nil )
841+ if err != nil {
842+ panic (err )
843+ }
844+ req .Header .Add ("Authorization" , "Bearer " + token .AccessToken )
845+ client := & http.Client {}
846+ resp , err := client .Do (req )
847+ if err != nil {
848+ panic (err )
849+ }
850+ defer resp .Body .Close ()
851+ body , err := io .ReadAll (resp .Body )
852+ if err != nil {
853+ panic (err )
854+ }
855+ if resp .StatusCode != http .StatusOK {
856+ panic (fmt .Errorf ("failed to retrieve user info. Status: %d, Body: %s" , resp .StatusCode , string (body )))
857+ }
858+ var data map [string ]interface {}
859+ if err := json .Unmarshal (body , & data ); err != nil {
860+ panic (err )
861+ }
862+ email , ok := data ["email" ].(string )
863+ if ! ok || email == "" {
864+ panic (errors .New ("authorization request to get user failed to return an email" ))
865+ }
866+ if strings .HasSuffix (email , "iam.gserviceaccount.com" ) {
867+ job .ExecutionIdentity = & aiplatformpb.NotebookExecutionJob_ServiceAccount {
868+ ServiceAccount : email ,
869+ }
870+ } else {
871+ job .ExecutionIdentity = & aiplatformpb.NotebookExecutionJob_ExecutionUser {
872+ ExecutionUser : email ,
873+ }
874+ }
875+ }
876+ return job , nil
877+ default :
878+ return nil , adbc.Error {
879+ Code : adbc .StatusInvalidArgument ,
880+ Msg : "Unsupported credential method in BigFrames" ,
881+ }
882+ }
883+ }
884+
763885func (c * connectionImpl ) hasImpersonationOptions () bool {
764886 return c .impersonateTargetPrincipal != "" ||
765887 len (c .impersonateDelegates ) > 0
0 commit comments