@@ -6,50 +6,19 @@ import { OAuthScopes, scopeDelimiter } from './OAuthScope';
66import IClientContext from '../../../contracts/IClientContext' ;
77import AuthenticationError from '../../../errors/AuthenticationError' ;
88
9+ export type DefaultOpenAuthUrlCallback = ( authUrl : string ) => Promise < void > ;
10+
11+ export type OpenAuthUrlCallback = ( authUrl : string , defaultOpenAuthUrl : DefaultOpenAuthUrlCallback ) => Promise < void > ;
12+
913export interface AuthorizationCodeOptions {
1014 client : BaseClient ;
1115 ports : Array < number > ;
1216 context : IClientContext ;
17+ openAuthUrl ?: OpenAuthUrlCallback ;
1318}
1419
15- async function startServer (
16- host : string ,
17- port : number ,
18- requestHandler : ( req : IncomingMessage , res : ServerResponse ) => void ,
19- ) : Promise < Server > {
20- const server = http . createServer ( requestHandler ) ;
21-
22- return new Promise ( ( resolve , reject ) => {
23- const errorListener = ( error : Error ) => {
24- server . off ( 'error' , errorListener ) ;
25- reject ( error ) ;
26- } ;
27-
28- server . on ( 'error' , errorListener ) ;
29- server . listen ( port , host , ( ) => {
30- server . off ( 'error' , errorListener ) ;
31- resolve ( server ) ;
32- } ) ;
33- } ) ;
34- }
35-
36- async function stopServer ( server : Server ) : Promise < void > {
37- if ( ! server . listening ) {
38- return ;
39- }
40-
41- return new Promise ( ( resolve , reject ) => {
42- const errorListener = ( error : Error ) => {
43- server . off ( 'error' , errorListener ) ;
44- reject ( error ) ;
45- } ;
46-
47- server . on ( 'error' , errorListener ) ;
48- server . close ( ( ) => {
49- server . off ( 'error' , errorListener ) ;
50- resolve ( ) ;
51- } ) ;
52- } ) ;
20+ async function defaultOpenAuthUrl ( authUrl : string ) : Promise < void > {
21+ await open ( authUrl ) ;
5322}
5423
5524export interface AuthorizationCodeFetchResult {
@@ -65,16 +34,12 @@ export default class AuthorizationCode {
6534
6635 private readonly host : string = 'localhost' ;
6736
68- private readonly ports : Array < number > ;
37+ private readonly options : AuthorizationCodeOptions ;
6938
7039 constructor ( options : AuthorizationCodeOptions ) {
7140 this . client = options . client ;
72- this . ports = options . ports ;
7341 this . context = options . context ;
74- }
75-
76- private async openUrl ( url : string ) {
77- return open ( url ) ;
42+ this . options = options ;
7843 }
7944
8045 public async fetch ( scopes : OAuthScopes ) : Promise < AuthorizationCodeFetchResult > {
@@ -84,7 +49,7 @@ export default class AuthorizationCode {
8449
8550 let receivedParams : CallbackParamsType | undefined ;
8651
87- const server = await this . startServer ( ( req , res ) => {
52+ const server = await this . createServer ( ( req , res ) => {
8853 const params = this . client . callbackParams ( req ) ;
8954 if ( params . state === state ) {
9055 receivedParams = params ;
@@ -108,7 +73,8 @@ export default class AuthorizationCode {
10873 redirect_uri : redirectUri ,
10974 } ) ;
11075
111- await this . openUrl ( authUrl ) ;
76+ const openAuthUrl = this . options . openAuthUrl ?? defaultOpenAuthUrl ;
77+ await openAuthUrl ( authUrl , defaultOpenAuthUrl ) ;
11278 await server . stopped ( ) ;
11379
11480 if ( ! receivedParams || ! receivedParams . code ) {
@@ -122,11 +88,11 @@ export default class AuthorizationCode {
12288 return { code : receivedParams . code , verifier : verifierString , redirectUri } ;
12389 }
12490
125- private async startServer ( requestHandler : ( req : IncomingMessage , res : ServerResponse ) => void ) {
126- for ( const port of this . ports ) {
91+ private async createServer ( requestHandler : ( req : IncomingMessage , res : ServerResponse ) => void ) {
92+ for ( const port of this . options . ports ) {
12793 const host = this . host ; // eslint-disable-line prefer-destructuring
12894 try {
129- const server = await startServer ( host , port , requestHandler ) ; // eslint-disable-line no-await-in-loop
95+ const server = await this . startServer ( host , port , requestHandler ) ; // eslint-disable-line no-await-in-loop
13096 this . context . getLogger ( ) . log ( LogLevel . info , `Listening for OAuth authorization callback at ${ host } :${ port } ` ) ;
13197
13298 let resolveStopped : ( ) => void ;
@@ -140,7 +106,7 @@ export default class AuthorizationCode {
140106 host,
141107 port,
142108 server,
143- stop : ( ) => stopServer ( server ) . then ( resolveStopped ) . catch ( rejectStopped ) ,
109+ stop : ( ) => this . stopServer ( server ) . then ( resolveStopped ) . catch ( rejectStopped ) ,
144110 stopped : ( ) => stoppedPromise ,
145111 } ;
146112 } catch ( error ) {
@@ -156,6 +122,50 @@ export default class AuthorizationCode {
156122 throw new AuthenticationError ( 'Failed to start server: all ports are in use' ) ;
157123 }
158124
125+ private createHttpServer ( requestHandler : ( req : IncomingMessage , res : ServerResponse ) => void ) {
126+ return http . createServer ( requestHandler ) ;
127+ }
128+
129+ private async startServer (
130+ host : string ,
131+ port : number ,
132+ requestHandler : ( req : IncomingMessage , res : ServerResponse ) => void ,
133+ ) : Promise < Server > {
134+ const server = this . createHttpServer ( requestHandler ) ;
135+
136+ return new Promise ( ( resolve , reject ) => {
137+ const errorListener = ( error : Error ) => {
138+ server . off ( 'error' , errorListener ) ;
139+ reject ( error ) ;
140+ } ;
141+
142+ server . on ( 'error' , errorListener ) ;
143+ server . listen ( port , host , ( ) => {
144+ server . off ( 'error' , errorListener ) ;
145+ resolve ( server ) ;
146+ } ) ;
147+ } ) ;
148+ }
149+
150+ private async stopServer ( server : Server ) : Promise < void > {
151+ if ( ! server . listening ) {
152+ return ;
153+ }
154+
155+ return new Promise ( ( resolve , reject ) => {
156+ const errorListener = ( error : Error ) => {
157+ server . off ( 'error' , errorListener ) ;
158+ reject ( error ) ;
159+ } ;
160+
161+ server . on ( 'error' , errorListener ) ;
162+ server . close ( ( ) => {
163+ server . off ( 'error' , errorListener ) ;
164+ resolve ( ) ;
165+ } ) ;
166+ } ) ;
167+ }
168+
159169 private renderCallbackResponse ( ) : string {
160170 const applicationName = 'Databricks Sql Connector' ;
161171
0 commit comments