Skip to content

Commit 19eaf8b

Browse files
authored
Merge pull request #1 from aws-gopher/type-safe
Refactor SDK to be type safe
2 parents dd7d363 + 1b61a41 commit 19eaf8b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

73 files changed

+2872
-57
lines changed

.golangci.yml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,14 @@ linters:
5858
check-type-assertions: false
5959
exclude-functions:
6060
- io/ioutil.ReadFile
61-
- io.Copy(*bytes.Buffer)
62-
- io.Copy(os.Stdout)
61+
- io.Copy
6362
- (io.Closer).Close
63+
- (net/http.ResponseWriter).Write
64+
- (io.Writer).Write # errcheck matches methods as (pkg.Type).Method
65+
gosec:
66+
excludes:
67+
- G101
68+
- G104
6469

6570
# Issues configuration
6671
issues:

bearer.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@ type bearer struct {
99
rt http.RoundTripper
1010
}
1111

12+
// HeaderKey is "Unstructured-API-Key", which is the header where Unstructured expects to find the API key.
13+
const HeaderKey = "Unstructured-API-Key"
14+
1215
// RoundTrip implements the http.RoundTripper interface.
1316
func (b *bearer) RoundTrip(req *http.Request) (*http.Response, error) {
14-
req.Header.Set("Unstructured-API-Key", b.key)
17+
req.Header.Set(HeaderKey, b.key)
1518

1619
// This is implementing the http.RoundTripper interface, errors should be passed through as-is
1720
return b.rt.RoundTrip(req) //nolint:wrapcheck

block_types.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package unstructured
2+
3+
// BlockType is a type that represents a block type.
4+
type BlockType string
5+
6+
// BlockType constants.
7+
const (
8+
BlockTypeImage BlockType = "Image"
9+
BlockTypeTable BlockType = "Table"
10+
)

chunker_character.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package unstructured
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
)
7+
8+
// ChunkerCharacter is a node that chunks text by character.
9+
type ChunkerCharacter struct {
10+
ID string `json:"-"`
11+
Name string `json:"-"`
12+
APIURL string `json:"unstructured_api_url,omitempty"`
13+
APIKey string `json:"unstructured_api_key,omitempty"`
14+
IncludeOrigElements bool `json:"include_orig_elements,omitempty"`
15+
NewAfterNChars int `json:"new_after_n_chars,omitempty"`
16+
MaxCharacters int `json:"max_characters,omitempty"`
17+
Overlap int `json:"overlap,omitempty"`
18+
OverlapAll bool `json:"overlap_all"`
19+
ContextualChunkingStrategy ChunkingStrategy `json:"contextual_chunking_strategy,omitempty"`
20+
}
21+
22+
// ChunkingStrategy is a strategy for contextual chunking.
23+
type ChunkingStrategy string
24+
25+
// ChunkingStrategyV1 is a strategy for contextual chunking.
26+
const ChunkingStrategyV1 = "v1"
27+
28+
var _ WorkflowNode = new(ChunkerCharacter)
29+
30+
// isNode implements the WorkflowNode interface.
31+
func (c ChunkerCharacter) isNode() {}
32+
33+
// MarshalJSON implements the json.Marshaler interface.
34+
func (c ChunkerCharacter) MarshalJSON() ([]byte, error) {
35+
type alias ChunkerCharacter
36+
37+
data, err := json.Marshal(alias(c))
38+
if err != nil {
39+
return nil, fmt.Errorf("failed to marshal chunker character: %w", err)
40+
}
41+
42+
headerData, err := json.Marshal(header{
43+
ID: c.ID,
44+
Name: c.Name,
45+
Type: nodeTypeChunk,
46+
Subtype: string(ChunkerSubtypeCharacter),
47+
Settings: json.RawMessage(data),
48+
})
49+
if err != nil {
50+
return nil, fmt.Errorf("failed to marshal chunker character header: %w", err)
51+
}
52+
53+
return headerData, nil
54+
}

chunker_page.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package unstructured
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
)
7+
8+
// ChunkerPage is a node that chunks text by page.
9+
type ChunkerPage struct {
10+
ID string `json:"-"`
11+
Name string `json:"-"`
12+
APIURL string `json:"unstructured_api_url,omitempty"`
13+
APIKey string `json:"unstructured_api_key,omitempty"`
14+
IncludeOrigElements bool `json:"include_orig_elements,omitempty"`
15+
NewAfterNChars int `json:"new_after_n_chars,omitempty"`
16+
MaxCharacters int `json:"max_characters,omitempty"`
17+
Overlap int `json:"overlap,omitempty"`
18+
OverlapAll bool `json:"overlap_all"`
19+
Strategy ChunkingStrategy `json:"contextual_chunking_strategy,omitempty"`
20+
}
21+
22+
var _ WorkflowNode = new(ChunkerPage)
23+
24+
// isNode implements the WorkflowNode interface.
25+
func (c ChunkerPage) isNode() {}
26+
27+
// MarshalJSON implements the json.Marshaler interface.
28+
func (c ChunkerPage) MarshalJSON() ([]byte, error) {
29+
type alias ChunkerPage
30+
31+
data, err := json.Marshal(alias(c))
32+
if err != nil {
33+
return nil, fmt.Errorf("failed to marshal chunker page: %w", err)
34+
}
35+
36+
headerData, err := json.Marshal(header{
37+
ID: c.ID,
38+
Name: c.Name,
39+
Type: nodeTypeChunk,
40+
Subtype: string(ChunkerSubtypePage),
41+
Settings: json.RawMessage(data),
42+
})
43+
if err != nil {
44+
return nil, fmt.Errorf("failed to marshal chunker page header: %w", err)
45+
}
46+
47+
return headerData, nil
48+
}

chunker_similarity.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package unstructured
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
)
7+
8+
// ChunkerSimilarity is a node that chunks text by similarity.
9+
type ChunkerSimilarity struct {
10+
ID string `json:"-"`
11+
Name string `json:"-"`
12+
APIURL string `json:"unstructured_api_url,omitempty"`
13+
APIKey string `json:"unstructured_api_key,omitempty"`
14+
IncludeOrigElements bool `json:"include_orig_elements,omitempty"`
15+
NewAfterNChars int `json:"new_after_n_chars,omitempty"`
16+
MaxCharacters int `json:"max_characters,omitempty"`
17+
Overlap int `json:"overlap,omitempty"`
18+
OverlapAll bool `json:"overlap_all"`
19+
Strategy ChunkingStrategy `json:"contextual_chunking_strategy,omitempty"`
20+
}
21+
22+
var _ WorkflowNode = new(ChunkerSimilarity)
23+
24+
// isNode implements the WorkflowNode interface.
25+
func (c ChunkerSimilarity) isNode() {}
26+
27+
// MarshalJSON implements the json.Marshaler interface.
28+
func (c ChunkerSimilarity) MarshalJSON() ([]byte, error) {
29+
type alias ChunkerSimilarity
30+
31+
data, err := json.Marshal(alias(c))
32+
if err != nil {
33+
return nil, fmt.Errorf("failed to marshal chunker similarity: %w", err)
34+
}
35+
36+
headerData, err := json.Marshal(header{
37+
ID: c.ID,
38+
Name: c.Name,
39+
Type: nodeTypeChunk,
40+
Subtype: string(ChunkerSubtypeSimilarity),
41+
Settings: json.RawMessage(data),
42+
})
43+
if err != nil {
44+
return nil, fmt.Errorf("failed to marshal chunker similarity header: %w", err)
45+
}
46+
47+
return headerData, nil
48+
}

chunker_title.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package unstructured
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
)
7+
8+
// ChunkerTitle is a node that chunks text by title.
9+
type ChunkerTitle struct {
10+
ID string `json:"-"`
11+
Name string `json:"-"`
12+
APIURL string `json:"unstructured_api_url,omitempty"`
13+
APIKey string `json:"unstructured_api_key,omitempty"`
14+
CombineTextUnderN int `json:"combine_text_under_n_chars,omitempty"`
15+
IncludeOrigElements bool `json:"include_orig_elements,omitempty"`
16+
NewAfterNChars int `json:"new_after_n_chars,omitempty"`
17+
MaxCharacters int `json:"max_characters,omitempty"`
18+
Overlap int `json:"overlap,omitempty"`
19+
OverlapAll bool `json:"overlap_all"`
20+
ContextualChunkingStrategy ChunkingStrategy `json:"contextual_chunking_strategy,omitempty"`
21+
}
22+
23+
var _ WorkflowNode = new(ChunkerTitle)
24+
25+
// isNode implements the WorkflowNode interface.
26+
func (c ChunkerTitle) isNode() {}
27+
28+
// MarshalJSON implements the json.Marshaler interface.
29+
func (c ChunkerTitle) MarshalJSON() ([]byte, error) {
30+
type alias ChunkerTitle
31+
32+
data, err := json.Marshal(alias(c))
33+
if err != nil {
34+
return nil, fmt.Errorf("failed to marshal chunker title: %w", err)
35+
}
36+
37+
headerData, err := json.Marshal(header{
38+
ID: c.ID,
39+
Name: c.Name,
40+
Type: nodeTypeChunk,
41+
Subtype: string(ChunkerSubtypeTitle),
42+
Settings: json.RawMessage(data),
43+
})
44+
if err != nil {
45+
return nil, fmt.Errorf("failed to marshal chunker title header: %w", err)
46+
}
47+
48+
return headerData, nil
49+
}

chunker_type.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package unstructured
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
)
7+
8+
// ChunkerSubtype is a type that represents a chunker subtype.
9+
type ChunkerSubtype string
10+
11+
// ChunkerSubtype constants.
12+
const (
13+
ChunkerSubtypeCharacter ChunkerSubtype = "chunk_by_character"
14+
ChunkerSubtypeTitle ChunkerSubtype = "chunk_by_title"
15+
ChunkerSubtypePage ChunkerSubtype = "chunk_by_page"
16+
ChunkerSubtypeSimilarity ChunkerSubtype = "chunk_by_similarity"
17+
)
18+
19+
func unmarshalChunker(header header) (WorkflowNode, error) {
20+
var chunker WorkflowNode
21+
22+
switch ChunkerSubtype(header.Subtype) {
23+
case ChunkerSubtypeCharacter:
24+
chunker = &ChunkerCharacter{
25+
ID: header.ID,
26+
Name: header.Name,
27+
}
28+
29+
case ChunkerSubtypeTitle:
30+
chunker = &ChunkerTitle{
31+
ID: header.ID,
32+
Name: header.Name,
33+
}
34+
35+
case ChunkerSubtypePage:
36+
chunker = &ChunkerPage{
37+
ID: header.ID,
38+
Name: header.Name,
39+
}
40+
41+
case ChunkerSubtypeSimilarity:
42+
chunker = &ChunkerSimilarity{
43+
ID: header.ID,
44+
Name: header.Name,
45+
}
46+
47+
default:
48+
return nil, fmt.Errorf("unknown Chunker strategy: %s", header.Subtype)
49+
}
50+
51+
if err := json.Unmarshal(header.Settings, chunker); err != nil {
52+
return nil, fmt.Errorf("failed to unmarshal Chunker node: %w", err)
53+
}
54+
55+
return chunker, nil
56+
}

client.go

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ package unstructured
33
import (
44
"cmp"
55
"encoding/json"
6+
"errors"
67
"fmt"
8+
"io"
79
"net/http"
810
"net/url"
911
"os"
@@ -55,6 +57,15 @@ func WithKey(key string) Option {
5557
}
5658
}
5759

60+
// WithClient returns an Option that sets the HTTP client to use for requests.
61+
// If no client is provided, the client will default to [http.DefaultClient].
62+
func WithClient(hc *http.Client) Option {
63+
return func(c *Client) error {
64+
c.hc = hc
65+
return nil
66+
}
67+
}
68+
5869
// New creates a new Client instance with the provided options.
5970
// If the `UNSTRUCTURED_API_KEY` environment variable is set, it will be used as the API key for authentication.
6071
// If the `UNSTRUCTURED_API_URL` environment variable is set to a valid URL, it will be used as the base URL for the Unstructured.io API.
@@ -104,17 +115,26 @@ func (c *Client) do(req *http.Request, out any) error {
104115
defer func() { _ = resp.Body.Close() }()
105116

106117
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
118+
body, err := io.ReadAll(resp.Body)
119+
if err != nil {
120+
return fmt.Errorf("failed to read response body: %w", err)
121+
}
122+
107123
// Handle 422 validation errors specifically
108124
if resp.StatusCode == http.StatusUnprocessableEntity {
109125
var validationErr HTTPValidationError
110-
if err := json.NewDecoder(resp.Body).Decode(&validationErr); err != nil {
111-
return fmt.Errorf("failed to decode validation error response: %w", err)
126+
if err := json.Unmarshal(body, &validationErr); err == nil {
127+
return &APIError{
128+
Code: resp.StatusCode,
129+
Err: &validationErr,
130+
}
112131
}
113-
114-
return &validationErr
115132
}
116133

117-
return fmt.Errorf("unsuccessful response: %s", resp.Status)
134+
return &APIError{
135+
Code: resp.StatusCode,
136+
Err: errors.New(string(body)),
137+
}
118138
}
119139

120140
if out != nil {

0 commit comments

Comments
 (0)