1- package http
1+ package http_test
22
33import (
44 "encoding/json"
5- "io/ioutil "
5+ "errors "
66 "net/http"
77 "net/http/httptest"
88 "net/http/httputil"
9- "strings "
9+ "net/url "
1010 "testing"
1111 "time"
1212
1313 itn "github.com/1set/starlet/internal"
14+ lh "github.com/1set/starlet/lib/http"
1415 "github.com/1set/starlight/convert"
1516 "go.starlark.net/starlark"
1617 "go.starlark.net/starlarktest"
@@ -27,7 +28,7 @@ func TestAsString(t *testing.T) {
2728 }
2829
2930 for i , c := range cases {
30- got , err := AsString (c .in )
31+ got , err := lh . AsString (c .in )
3132 if ! (err == nil && c .err == "" || err != nil && err .Error () == c .err ) {
3233 t .Errorf ("case %d error mismatch. expected: '%s', got: '%s'" , i , c .err , err )
3334 continue
@@ -49,7 +50,7 @@ func TestLoadModule_HTTP_One(t *testing.T) {
4950 defer ts .Close ()
5051 starlark .Universe ["test_server_url" ] = starlark .String (ts .URL )
5152
52- thread := & starlark.Thread {Load : itn .NewAssertLoader (ModuleName , LoadModule )}
53+ thread := & starlark.Thread {Load : itn .NewAssertLoader (lh . ModuleName , lh . LoadModule )}
5354 starlarktest .SetReporter (thread , t )
5455
5556 code := itn .HereDoc (`
@@ -296,7 +297,7 @@ func TestLoadModule_HTTP(t *testing.T) {
296297 {
297298 name : `POST with UA Set` ,
298299 preset : func () {
299- UserAgent = "GqQdYX3eIJw2DTt"
300+ lh . UserAgent = "GqQdYX3eIJw2DTt"
300301 },
301302 script : itn .HereDoc (`
302303 load('http', 'post')
@@ -575,8 +576,8 @@ func TestLoadModule_HTTP(t *testing.T) {
575576 if tt .preset != nil {
576577 tt .preset ()
577578 }
578- TimeoutSecond = 30.0
579- res , err := itn .ExecModuleWithErrorTest (t , ModuleName , LoadModule , tt .script , tt .wantErr , nil )
579+ lh . TimeoutSecond = 30.0
580+ res , err := itn .ExecModuleWithErrorTest (t , lh . ModuleName , lh . LoadModule , tt .script , tt .wantErr , nil )
580581 if (err != nil ) != (tt .wantErr != "" ) {
581582 t .Errorf ("http(%q) expects error = '%v', actual error = '%v', result = %v" , tt .name , tt .wantErr , err , res )
582583 return
@@ -585,65 +586,91 @@ func TestLoadModule_HTTP(t *testing.T) {
585586 }
586587}
587588
588- // we're ok with testing private functions if it simplifies the test :)
589- func TestSetBody (t * testing.T ) {
590- fd := map [string ]string {
591- "foo" : "bar baz" ,
589+ // DomainWhitelistGuard allows requests only to domains in its whitelist.
590+ type DomainWhitelistGuard struct {
591+ whitelist map [string ]struct {} // Set of allowed domains
592+ }
593+
594+ // NewDomainWhitelistGuard creates a new DomainWhitelistGuard with the specified domains.
595+ func NewDomainWhitelistGuard (domains []string ) * DomainWhitelistGuard {
596+ whitelist := make (map [string ]struct {})
597+ for _ , domain := range domains {
598+ whitelist [domain ] = struct {}{}
592599 }
600+ return & DomainWhitelistGuard {whitelist : whitelist }
601+ }
593602
594- cases := []struct {
595- rawBody starlark.String
596- formData map [string ]string
597- formEncoding starlark.String
598- jsonData starlark.Value
599- body string
600- err string
601- }{
602- {starlark .String ("hallo" ), nil , starlark .String ("" ), nil , "hallo" , "" },
603- {starlark .String ("" ), fd , starlark .String ("" ), nil , "foo=bar+baz" , "" },
604- // TODO - this should check multipart form data is being set
605- {starlark .String ("" ), fd , starlark .String ("multipart/form-data" ), nil , "" , "" },
606- {starlark .String ("" ), nil , starlark .String ("" ), starlark.Tuple {starlark .Bool (true ), starlark .MakeInt (1 ), starlark .String ("der" )}, "[true,1,\" der\" ]" , "" },
603+ // Allowed checks if the request's domain is in the whitelist.
604+ func (g * DomainWhitelistGuard ) Allowed (thread * starlark.Thread , req * http.Request ) (* http.Request , error ) {
605+ if _ , ok := g .whitelist [req .URL .Host ]; ok {
606+ // Domain is in the whitelist, allow the request
607+ return req , nil
607608 }
609+ // Domain is not in the whitelist, deny the request
610+ return nil , errors .New ("request to this domain is not allowed" )
611+ }
608612
609- for i , c := range cases {
610- var formData * starlark.Dict
611- if c .formData != nil {
612- formData = starlark .NewDict (len (c .formData ))
613- for k , v := range c .formData {
614- if err := formData .SetKey (starlark .String (k ), starlark .String (v )); err != nil {
615- t .Fatal (err )
616- }
617- }
618- }
613+ func TestLoadModule_CustomLoad (t * testing.T ) {
614+ md := lh .NewModule ()
615+ proxyURL , _ := url .Parse ("http://127.0.0.1:9999" )
616+ client := & http.Client {
617+ Transport : & http.Transport {
618+ Proxy : http .ProxyURL (proxyURL ),
619+ },
620+ }
621+ md .SetClient (client )
622+ guard := NewDomainWhitelistGuard ([]string {"allowed.com" })
623+ md .SetGuard (guard )
619624
620- req := httptest .NewRequest ("get" , "https://example.com" , nil )
621- err := setBody (req , c .rawBody , formData , c .formEncoding , c .jsonData )
622- if ! (err == nil && c .err == "" || (err != nil && err .Error () == c .err )) {
623- t .Errorf ("case %d error mismatch. expected: %s, got: %s" , i , c .err , err )
624- continue
625+ httpHand := http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
626+ b , err := httputil .DumpRequest (r , true )
627+ if err != nil {
628+ t .Errorf ("Error dumping request: %v" , err )
625629 }
630+ t .Logf ("Web server received request: [[%s]]" , b )
631+ time .Sleep (10 * time .Millisecond )
632+ w .Write (b )
633+ })
634+ ts := httptest .NewServer (httpHand )
635+ defer ts .Close ()
626636
627- if strings .HasPrefix (req .Header .Get ("Content-Type" ), "multipart/form-data;" ) {
628- if err := req .ParseMultipartForm (0 ); err != nil {
629- t .Fatal (err )
630- }
631-
632- for k , v := range c .formData {
633- fv := req .FormValue (k )
634- if fv != v {
635- t .Errorf ("case %d error mismatch. expected %s=%s, got: %s" , i , k , v , fv )
636- }
637- }
638- } else {
639- body , err := ioutil .ReadAll (req .Body )
640- if err != nil {
641- t .Fatal (err )
637+ tests := []struct {
638+ name string
639+ preset func ()
640+ script string
641+ wantErr string
642+ }{
643+ {
644+ name : `Simple GET` ,
645+ script : itn .HereDoc (`
646+ load('http', 'get')
647+ res = get("http://allowed.com/hello")
648+ assert.eq(res.status_code, 200)
649+ ` ),
650+ wantErr : `proxyconnect tcp: dial tcp 127.0.0.1:9999: connect` ,
651+ },
652+ {
653+ name : `Not Allowed` ,
654+ script : itn .HereDoc (`
655+ load('http', 'get')
656+ res = get("http://topsecret.com/text")
657+ assert.eq(res.status_code, 200)
658+ ` ),
659+ wantErr : `request to this domain is not allowed` ,
660+ },
661+ }
662+ for _ , tt := range tests {
663+ t .Run (tt .name , func (t * testing.T ) {
664+ if tt .preset != nil {
665+ tt .preset ()
642666 }
643-
644- if string (body ) != c .body {
645- t .Errorf ("case %d body mismatch. expected: %s, got: %s" , i , c .body , string (body ))
667+ res , err := itn .ExecModuleWithErrorTest (t , lh .ModuleName , md .LoadModule , tt .script , tt .wantErr , starlark.StringDict {
668+ "test_server_url" : starlark .String (ts .URL ),
669+ })
670+ if (err != nil ) != (tt .wantErr != "" ) {
671+ t .Errorf ("http(%q) expects error = '%v', actual error = '%v', result = %v" , tt .name , tt .wantErr , err , res )
672+ return
646673 }
647- }
674+ })
648675 }
649676}
0 commit comments