@@ -22,17 +22,16 @@ import (
2222	"fmt" 
2323	"net/url" 
2424	"strconv" 
25+ 	"strings" 
2526	"time" 
2627
27- 	"github.com/google/uuid" 
28- 
2928	"github.com/dapr/components-contrib/state" 
3029	stateutils "github.com/dapr/components-contrib/state/utils" 
3130	"github.com/dapr/kit/logger" 
3231	"github.com/dapr/kit/metadata" 
3332
34- 	// Blank import for the underlying Oracle Database driver. 
35- 	_  "github.com/sijms/go-ora/v2" 
33+ 	"github.com/google/uuid" 
34+ 	goora  "github.com/sijms/go-ora/v2" 
3635)
3736
3837const  (
@@ -78,23 +77,26 @@ func parseMetadata(meta map[string]string) (oracleDatabaseMetadata, error) {
7877// Init sets up OracleDatabase connection and ensures that the state table exists. 
7978func  (o  * oracleDatabaseAccess ) Init (ctx  context.Context , metadata  state.Metadata ) error  {
8079	meta , err  :=  parseMetadata (metadata .Properties )
81- 	o .metadata  =  meta 
8280	if  err  !=  nil  {
8381		return  err 
8482	}
85- 	if  o .metadata .ConnectionString  !=  ""  {
86- 		o .connectionString  =  meta .ConnectionString 
87- 	} else  {
83+ 
84+ 	o .metadata  =  meta 
85+ 
86+ 	if  o .metadata .ConnectionString  ==  ""  {
8887		o .logger .Error ("Missing Oracle Database connection string" )
8988		return  errors .New (errMissingConnectionString )
9089	}
91- 	if  o .metadata .OracleWalletLocation  !=  ""  {
92- 		o .connectionString  +=  "?TRACE FILE=trace.log&SSL=enable&SSL Verify=false&WALLET="  +  url .QueryEscape (o .metadata .OracleWalletLocation )
90+ 
91+ 	o .connectionString , err  =  parseConnectionString (meta )
92+ 	if  err  !=  nil  {
93+ 		o .logger .Error (err )
94+ 		return  err 
9395	}
96+ 
9497	db , err  :=  sql .Open ("oracle" , o .connectionString )
9598	if  err  !=  nil  {
9699		o .logger .Error (err )
97- 
98100		return  err 
99101	}
100102
@@ -105,12 +107,62 @@ func (o *oracleDatabaseAccess) Init(ctx context.Context, metadata state.Metadata
105107		return  err 
106108	}
107109
108- 	err  =  o .ensureStateTable (o .metadata .TableName )
110+ 	return  o .ensureStateTable (o .metadata .TableName )
111+ }
112+ 
113+ func  parseConnectionString (meta  oracleDatabaseMetadata ) (string , error ) {
114+ 	username  :=  "" 
115+ 	password  :=  "" 
116+ 	host  :=  "" 
117+ 	port  :=  0 
118+ 	serviceName  :=  "" 
119+ 	query  :=  url.Values {}
120+ 	options  :=  make (map [string ]string )
121+ 
122+ 	connectionStringURL , err  :=  url .Parse (meta .ConnectionString )
109123	if  err  !=  nil  {
110- 		return  err 
124+ 		return  "" ,  err 
111125	}
112126
113- 	return  nil 
127+ 	isURL  :=  connectionStringURL .Scheme  !=  ""  &&  connectionStringURL .Host  !=  "" 
128+ 	if  isURL  {
129+ 		username  =  connectionStringURL .User .Username ()
130+ 		password , _  =  connectionStringURL .User .Password ()
131+ 		query  =  connectionStringURL .Query ()
132+ 		serviceName  =  strings .TrimPrefix (connectionStringURL .Path , "/" )
133+ 		if  strings .Contains (connectionStringURL .Host , ":" ) {
134+ 			host  =  strings .Split (connectionStringURL .Host , ":" )[0 ]
135+ 		} else  {
136+ 			host  =  connectionStringURL .Host 
137+ 		}
138+ 	} else  {
139+ 		host  =  connectionStringURL .Path 
140+ 	}
141+ 
142+ 	if  connectionStringURL .Port () !=  ""  {
143+ 		port , err  =  strconv .Atoi (connectionStringURL .Port ())
144+ 		if  err  !=  nil  {
145+ 			return  "" , err 
146+ 		}
147+ 	}
148+ 
149+ 	for  k , v  :=  range  query  {
150+ 		options [k ] =  v [0 ]
151+ 	}
152+ 
153+ 	if  meta .OracleWalletLocation  !=  ""  {
154+ 		options ["WALLET" ] =  meta .OracleWalletLocation 
155+ 		options ["TRACE FILE" ] =  "trace.log" 
156+ 		options ["SSL" ] =  "enable" 
157+ 		options ["SSL Verify" ] =  "false" 
158+ 	}
159+ 
160+ 	if  strings .Contains (host , "(DESCRIPTION" ) {
161+ 		// the connection string is a URL that contains the descriptor and authentication info 
162+ 		return  goora .BuildJDBC (username , password , host , options ), nil 
163+ 	} else  {
164+ 		return  goora .BuildUrl (host , port , serviceName , username , password , options ), nil 
165+ 	}
114166}
115167
116168// Set makes an insert or update to the database. 
@@ -170,7 +222,7 @@ func (o *oracleDatabaseAccess) doSet(ctx context.Context, db querier, req *state
170222		if  req .Options .Concurrency  ==  state .FirstWrite  {
171223			stmt  =  `INSERT INTO `  +  o .metadata .TableName  +  ` 
172224				(key, value, binary_yn, etag, expiration_time) 
173- 			VALUES   
225+ 			VALUES 
174226				(:key, :value, :binary_yn, :etag, `  +  ttlStatement  +  `) ` 
175227		} else  {
176228			// As per Discord Thread https://discord.com/channels/778680217417809931/901141713089863710/938520959562952735 expiration time is reset in case of an update. 
0 commit comments