diff --git a/config.go b/config.go index 0aa9890..e127477 100644 --- a/config.go +++ b/config.go @@ -35,6 +35,7 @@ package main import ( "bufio" + "context" "crypto" "crypto/ecdsa" "crypto/rsa" @@ -52,6 +53,8 @@ import ( "strings" "time" + "github.com/Azure/go-ntlmssp" + "github.com/dpotapov/go-spnego" "github.com/moriyoshi/mimetypes" "github.com/pkg/errors" "gopkg.in/yaml.v2" @@ -88,9 +91,24 @@ type MITMConfig struct { DisableCache bool } +type DirectSetter func(*http.Request) error + +type UserPasswordPair struct { + User string + Password string +} + +type RoundTripperWrapper func(http.RoundTripper, *http.Request) (*http.Response, error) + +type AuthConfig struct { + RoundTripperWrapperFactory func() (RoundTripperWrapper, error) + CredentialsProvider func(context.Context) (interface{}, error) +} + type ProxyConfig struct { HTTPProxy *url.URL HTTPSProxy *url.URL + Auth AuthConfig IncludedHosts []HostPortPair ExcludedHosts []HostPortPair TLSConfig *tls.Config @@ -257,6 +275,75 @@ func parseUrlOrHostPortPair(urlOrHostPortPair string) (retval *url.URL, err erro return } +func (ctx *ConfigReaderContext) lookupRoundTripperWrapperFactory(typ string) (func() (RoundTripperWrapper, error), error) { + switch typ { + case "ntlm", "ntlm-ssp": + return func() (RoundTripperWrapper, error) { + return func(rt http.RoundTripper, req *http.Request) (*http.Response, error) { + return ntlmssp.Negotiator{rt}.RoundTrip(req) + }, nil + }, nil + case "gssapi", "spnego": + return func() (RoundTripperWrapper, error) { + p := spnego.New() + return func(rt http.RoundTripper, req *http.Request) (*http.Response, error) { + err := p.SetSPNEGOHeader(req) + if err != nil { + return nil, err + } + return rt.RoundTrip(req) + }, nil + }, nil + default: + return nil, errors.Errorf("unknown roundtripper wrapper: %s", typ) + } +} + +func (ctx *ConfigReaderContext) extractAuthConfig(deref dereference) (retval AuthConfig, err error) { + err = deref.multi( + "type", func(typ string) error { + var err error + retval.RoundTripperWrapperFactory, err = ctx.lookupRoundTripperWrapperFactory(typ) + return err + }, + "credentials", func(deref dereference) error { + var upp *UserPasswordPair + err := deref.multi( + "user", func(v string) error { + if upp == nil { + upp = &UserPasswordPair{} + } + upp.User = v + return nil + }, + "password", func(v string) error { + if upp == nil { + upp = &UserPasswordPair{} + } + upp.Password = v + return nil + }, + ) + if err != nil { + return err + } + if upp != nil { + retval.CredentialsProvider = func(_ context.Context) (interface{}, error) { + return upp, nil + } + } else { + retval.CredentialsProvider = func(_ context.Context) (interface{}, error) { + return func(_ *http.Request) error { + return nil + }, nil + } + } + return nil + }, + ) + return +} + func (ctx *ConfigReaderContext) extractProxyConfig(deref dereference) (retval ProxyConfig, err error) { err = deref.multi( "proxy", func(deref dereference) error { @@ -279,6 +366,11 @@ func (ctx *ConfigReaderContext) extractProxyConfig(deref dereference) (retval Pr retval.HTTPSProxy = httpsProxyUrl return nil }, + "auth", func(deref dereference) error { + var err error + retval.Auth, err = ctx.extractAuthConfig(deref) + return err + }, "included", func(includedHosts []string) error { retval.IncludedHosts, err = convertUnparsedHostsIntoPairs(includedHosts) if err != nil { diff --git a/example.yml b/example.yml index 1113b34..47ae406 100644 --- a/example.yml +++ b/example.yml @@ -2,6 +2,11 @@ proxy: http: http://127.0.0.1:9080/ https: http://127.0.0.1:9080/ + auth: + type: ntlm + credentials: + user: DOMAIN\\FOO + password: PASS excluded: - localhost:8081 - localhost:8082 diff --git a/go.mod b/go.mod index 322430d..d75093f 100644 --- a/go.mod +++ b/go.mod @@ -1,14 +1,15 @@ module github.com/moriyoshi/devproxy require ( + github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c + github.com/dpotapov/go-spnego v0.0.0-20190506202455-c2c609116ad0 github.com/moriyoshi/mimetypes v1.0.0 github.com/moriyoshi/simplefiletx v1.0.0 github.com/pkg/errors v0.9.1 github.com/shibukawa/configdir v0.0.0-20170330084843-e180dbdc8da0 github.com/sirupsen/logrus v1.3.0 github.com/stretchr/testify v1.2.2 - golang.org/x/net v0.0.0-20190125091013-d26f9f9a57f3 - golang.org/x/text v0.3.0 // indirect + golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 gopkg.in/yaml.v2 v2.3.0 ) diff --git a/go.sum b/go.sum index ec74ab5..d383946 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,13 @@ +github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c h1:/IBSNwUN8+eKzUzbJPqhK839ygXJ82sde8x3ogr6R28= +github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dpotapov/go-spnego v0.0.0-20190506202455-c2c609116ad0 h1:Hhh7nu7CfFVlnBJqmDDUh+j1H5fqjLMzM4czZzNNJGM= +github.com/dpotapov/go-spnego v0.0.0-20190506202455-c2c609116ad0/go.mod h1:P4f4MSk7h52F2PK0lCapn5+fu47Uf8aRdxDSqgezxZE= +github.com/hashicorp/go-uuid v1.0.1 h1:fv1ep09latC32wFoVwnqcnKJGnMSdBanPczbHAYm1BE= +github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/jcmturner/gofork v0.0.0-20190328161633-dc7c13fece03 h1:FUwcHNlEqkqLjLBdCp5PRlCFijNjvcYANOZXzCfXwCM= +github.com/jcmturner/gofork v0.0.0-20190328161633-dc7c13fece03/go.mod h1:MK8+TM0La+2rjBD4jE12Kj1pCCxK7d2LK/UM3ncEo0o= github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/moriyoshi/mimetypes v1.0.0 h1:nESQmdWurua/+7QzWnxMpbHYUd0mMsi6zKAkq3ZbU50= @@ -19,13 +27,29 @@ github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1 github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793 h1:u+LnwYTOOW7Ukr/fppxEb1Nwz0AtPflrblfvUudpo+I= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734 h1:p/H982KKEjUnLJkM3tt/LemDnOc1GiZL5FCVlORJ5zo= +golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/net v0.0.0-20190125091013-d26f9f9a57f3 h1:ulvT7fqt0yHWzpJwI57MezWnYDVpCAYBVuYst/L+fAY= golang.org/x/net v0.0.0-20190125091013-d26f9f9a57f3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33 h1:I6FyU15t786LL7oL/hn43zqTuEGr4PN7F4XJ1p4E3Y8= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/jcmturner/aescts.v1 v1.0.1 h1:cVVZBK2b1zY26haWB4vbBiZrfFQnfbTVrE3xZq6hrEw= +gopkg.in/jcmturner/aescts.v1 v1.0.1/go.mod h1:nsR8qBOg+OucoIW+WMhB3GspUQXq9XorLnQb9XtvcOo= +gopkg.in/jcmturner/dnsutils.v1 v1.0.1 h1:cIuC1OLRGZrld+16ZJvvZxVJeKPsvd5eUIvxfoN5hSM= +gopkg.in/jcmturner/dnsutils.v1 v1.0.1/go.mod h1:m3v+5svpVOhtFAP/wSz+yzh4Mc0Fg7eRhxkJMWSIz9Q= +gopkg.in/jcmturner/gokrb5.v5 v5.3.0 h1:RS1MYApX27Hx1Xw7NECs7XxGxxrm69/4OmaRuX9kwec= +gopkg.in/jcmturner/gokrb5.v5 v5.3.0/go.mod h1:oQz8Wc5GsctOTgCVyKad1Vw4TCWz5G6gfIQr88RPv4k= +gopkg.in/jcmturner/rpc.v0 v0.0.2 h1:wBTgrbL1qmLBUPsYVCqdJiI5aJgQhexmK+JkTHPUNJI= +gopkg.in/jcmturner/rpc.v0 v0.0.2/go.mod h1:NzMq6cRzR9lipgw7WxRBHNx5N8SifBuaCQsOT1kWY/E= gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/httpx/transport.go b/httpx/transport.go index ce8fb00..190076a 100644 --- a/httpx/transport.go +++ b/httpx/transport.go @@ -1406,8 +1406,12 @@ func addTLS(t *Transport, tlsConfig *tls.Config, name string, conn net.Conn, tra func (t *Transport) doDialFirstHop(ctx context.Context, cm ConnectMethod, trace *httptrace.ClientTrace) (conn net.Conn, tlsState *tls.ConnectionState, err error) { firstHopScheme := cm.Scheme() - if firstHopScheme == "https" && t.DialTLS != nil { - conn, err = t.DialTLS(ctx, "tcp", cm.Addr()) + if firstHopScheme == "https" && (t.DialTLS2 != nil || t.DialTLS != nil) { + if t.DialTLS2 != nil { + conn, err = t.DialTLS2(ctx, "tcp", cm.Addr(), t.TLSClientConfig) + } else { + conn, err = t.DialTLS(ctx, "tcp", cm.Addr()) + } if err != nil { return } diff --git a/main.go b/main.go index e66dc8e..26cc83a 100644 --- a/main.go +++ b/main.go @@ -34,6 +34,7 @@ package main import ( + "context" crand "crypto/rand" "crypto/sha1" "crypto/tls" @@ -64,6 +65,81 @@ import ( type TLSConfigFactory func(hostPortPairStr string, proxyCtx *OurProxyCtx) (*tls.Config, error) +type HttpxTransport interface { + http.RoundTripper + RegisterProtocol(string, http.RoundTripper) + CloseIdleConnections() + CancelRequest(*http.Request, error) + ConnectMethodForRequest(*httpx.TransportRequest) (httpx.ConnectMethod, error) + DoDial(context.Context, httpx.ConnectMethod) (net.Conn, *tls.ConnectionState, bool, func(http.Header), error) + DialContext(context.Context, string, string) (net.Conn, error) + DialTLS2(context.Context, string, string, *tls.Config) (net.Conn, error) +} + +type httpxTransportWrapper struct { + *httpx.Transport +} + +func (htw *httpxTransportWrapper) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + if htw.Transport.DialContext != nil { + return htw.Transport.DialContext(ctx, network, addr) + } + return (&net.Dialer{}).DialContext(ctx, "tcp", addr) +} + +func (htw *httpxTransportWrapper) DialTLS2(ctx context.Context, network, addr string, config *tls.Config) (net.Conn, error) { + if htw.Transport.DialTLS2 != nil { + return htw.Transport.DialTLS2(ctx, network, addr, config) + } + if config == nil { + if htw.Transport.DialTLS != nil { + return htw.Transport.DialTLS(ctx, "tcp", addr) + } else { + config = htw.Transport.TLSClientConfig + } + } + conn, err := net.Dial("tcp", addr) + if err != nil { + return nil, errors.Wrapf(err, "failed to connect to %v", addr) + } + tlsConfig := config.Clone() + tlsConfig.ServerName = splitHostPort(addr).Host + return tls.Client(conn, tlsConfig), nil +} + +func trAsInterface(tr *httpx.Transport) HttpxTransport { + return &httpxTransportWrapper{tr} +} + +type transportRoundTripperWrapper struct { + HttpxTransport + wrapper RoundTripperWrapper +} + +func (ttw *transportRoundTripperWrapper) RoundTrip(req *http.Request) (*http.Response, error) { + return ttw.wrapper(ttw, req) +} + +type roundTripperWrapper struct { + http.RoundTripper + wrapper RoundTripperWrapper +} + +func (rtw *roundTripperWrapper) RoundTrip(req *http.Request) (*http.Response, error) { + return rtw.wrapper(rtw, req) +} + +func wrapRoundTripper(tr http.RoundTripper, rtw RoundTripperWrapper) (http.RoundTripper, error) { + switch tr := tr.(type) { + case HttpxTransport: + return &transportRoundTripperWrapper{tr, rtw}, nil + case http.RoundTripper: + return &roundTripperWrapper{tr, rtw}, nil + default: + return nil, errors.Errorf("invalid round tripper: %T", tr) + } +} + type DevProxy struct { Logger *logrus.Logger LogWriter io.WriteCloser @@ -180,15 +256,28 @@ func (ctx *DevProxy) newProxyURLBuilder() func(*http.Request) (*url.URL, *tls.Co } } -func (ctx *DevProxy) newHttpTransport() *httpx.Transport { - transport := &httpx.Transport{ +func (ctx *DevProxy) newHttpTransport() (tr HttpxTransport, err error) { + tr = trAsInterface(&httpx.Transport{ TLSClientConfig: ctx.Config.MITM.ClientTLSConfigTemplate, Proxy2: ctx.newProxyURLBuilder(), + }) + tr.RegisterProtocol("fastcgi", &fastCGIRoundTripper{Logger: ctx.Logger}) + tr.RegisterProtocol("file", NewFileTransport(ctx.Config.FileTransport)) + tr.RegisterProtocol("x-http-redirect", &redirector{Logger: ctx.Logger}) + if rtwf := ctx.Config.Proxy.Auth.RoundTripperWrapperFactory; rtwf != nil { + var rt RoundTripperWrapper + rt, err = rtwf() + if err != nil { + return + } + var wrt http.RoundTripper + wrt, err = wrapRoundTripper(tr, rt) + if err != nil { + return + } + tr = wrt.(HttpxTransport) } - transport.RegisterProtocol("fastcgi", &fastCGIRoundTripper{Logger: ctx.Logger}) - transport.RegisterProtocol("file", NewFileTransport(ctx.Config.FileTransport)) - transport.RegisterProtocol("x-http-redirect", &redirector{Logger: ctx.Logger}) - return transport + return } var domainNameRegex = regexp.MustCompile("^[A-Za-z](?:[0-9A-Za-z-_]*[0-9A-Za-z])?$") @@ -366,15 +455,19 @@ func (ctx *DevProxy) checkIfTunnelRequestMatchesToUrl(url_ *url.URL, req *http.R return false } -func (ctx *DevProxy) newProxyHttpServer() *OurProxyHttpServer { +func (ctx *DevProxy) newProxyHttpServer() (*OurProxyHttpServer, error) { + tr, err := ctx.newHttpTransport() + if err != nil { + return nil, err + } return &OurProxyHttpServer{ Ctx: ctx, Logger: ctx.Logger, - Tr: ctx.newHttpTransport(), + Tr: tr, TLSConfigFactory: ctx.newTLSConfigFactory(), ResponseFilters: ctx.Config.ResponseFilters, SessionSerial: 0, - } + }, nil } func (ctx *DevProxy) Dispose() { @@ -445,7 +538,11 @@ func main() { ), } defer ctx.Dispose() - proxy := ctx.newProxyHttpServer() - logger.Infof("Listening on %s...", listenOn) + proxy, err := ctx.newProxyHttpServer() + if err != nil { + logger.Fatalf("could not initialize the proxy server: %s", err.Error()) + os.Exit(1) + } + logger.Infof("listening on %s...", listenOn) logger.Fatal(http.ListenAndServe(listenOn, proxy)) } diff --git a/server.go b/server.go index 8abc379..2aa866c 100644 --- a/server.go +++ b/server.go @@ -59,7 +59,7 @@ type ResponseFilter interface { type OurProxyHttpServer struct { Ctx *DevProxy Logger *logrus.Logger - Tr *httpx.Transport + Tr HttpxTransport TLSConfigFactory TLSConfigFactory ResponseFilters []ResponseFilter SessionSerial int64 @@ -72,7 +72,7 @@ type OurProxyCtx struct { Req *http.Request OrigResp *http.Response Resp *http.Response - Tr *httpx.Transport + Tr HttpxTransport ResponseFilters []ResponseFilter Error error Session int64 @@ -538,28 +538,13 @@ func buildFakeHTTPSRequestFromHostPortPair(addr string) *http.Request { } } -func (proxy *OurProxyHttpServer) doDial(ctx context.Context, addr string) (net.Conn, error) { - if proxy.Tr.DialContext != nil { - return proxy.Tr.DialContext(ctx, "tcp", addr) - } - return (&net.Dialer{}).DialContext(ctx, "tcp", addr) -} - -func (proxy *OurProxyHttpServer) doDialTLS(ctx context.Context, addr HostPortPair, tlsConfigTemplate *tls.Config) (net.Conn, error) { - if tlsConfigTemplate == nil { - if proxy.Tr.DialTLS != nil { - return proxy.Tr.DialTLS(ctx, "tcp", addr.String()) - } - tlsConfigTemplate = proxy.Tr.TLSClientConfig - } - conn, err := net.Dial("tcp", addr.String()) - if err != nil { - return nil, errors.Wrapf(err, "failed to connect to %v", addr) - } - tlsConfig := tlsConfigTemplate.Clone() - tlsConfig.ServerName = addr.Host - return tls.Client(conn, tlsConfig), nil -} +// func (proxy *OurProxyHttpServer) doDial(ctx context.Context, addr string) (net.Conn, error) { +// return proxy.Tr.DialContext(ctx, "tcp", addr) +// } +// +// func (proxy *OurProxyHttpServer) doDialTLS(ctx context.Context, addr HostPortPair, tlsConfigTemplate *tls.Config) (net.Conn, error) { +// return proxy.Tr.DialTLS2(ctx, "tcp", addr.String(), tlsConfigTemplate) +// } func (proxy *OurProxyHttpServer) ConnectDial(netCtx context.Context, addr string) (net.Conn, error) { cm, err := proxy.Tr.ConnectMethodForRequest(&httpx.TransportRequest{Request: buildFakeHTTPSRequestFromHostPortPair(addr), Extra: nil})