diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000..889ae34 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,7 @@ +// For format details, see https://aka.ms/devcontainer.json. For config options, see the +// README at: https://github.com/devcontainers/templates/tree/main/src/debian +{ + "name": "Development", + "image": "mcr.microsoft.com/devcontainers/go:1.23-bookworm", + "postCreateCommand": "go mod tidy" +} diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..fe48fb3 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,49 @@ +name: CI +on: + push: + branches-ignore: + - 'generated' + - 'codegen/**' + - 'integrated/**' + - 'stl-preview-head/**' + - 'stl-preview-base/**' + pull_request: + branches-ignore: + - 'stl-preview-head/**' + - 'stl-preview-base/**' + +jobs: + lint: + timeout-minutes: 10 + name: lint + runs-on: ${{ github.repository == 'stainless-sdks/scrapegraphai-sdk-go' && 'depot-ubuntu-24.04' || 'ubuntu-latest' }} + if: github.event_name == 'push' || github.event.pull_request.head.repo.fork + + steps: + - uses: actions/checkout@v4 + + - name: Setup go + uses: actions/setup-go@v5 + with: + go-version-file: ./go.mod + + - name: Run lints + run: ./scripts/lint + test: + timeout-minutes: 10 + name: test + runs-on: ${{ github.repository == 'stainless-sdks/scrapegraphai-sdk-go' && 'depot-ubuntu-24.04' || 'ubuntu-latest' }} + if: github.event_name == 'push' || github.event.pull_request.head.repo.fork + steps: + - uses: actions/checkout@v4 + + - name: Setup go + uses: actions/setup-go@v5 + with: + go-version-file: ./go.mod + + - name: Bootstrap + run: ./scripts/bootstrap + + - name: Run tests + run: ./scripts/test diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c6d0501 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +.prism.log +codegen.log +Brewfile.lock.json +.idea/ diff --git a/.release-please-manifest.json 
b/.release-please-manifest.json new file mode 100644 index 0000000..ba6c348 --- /dev/null +++ b/.release-please-manifest.json @@ -0,0 +1,3 @@ +{ + ".": "0.1.0-alpha.1" +} \ No newline at end of file diff --git a/.stats.yml b/.stats.yml new file mode 100644 index 0000000..2afc9f9 --- /dev/null +++ b/.stats.yml @@ -0,0 +1,4 @@ +configured_endpoints: 13 +openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/scrapegraphai%2Fscrapegraphai-sdk-1bb59000acb1297b8bbff37e8e34ed4751b9708a050695796e64ed3a8900ceef.yml +openapi_spec_hash: 141d244ff26ff28ab56901a19fb46347 +config_hash: 841efa635faf188bb88c338627bf9658 diff --git a/Brewfile b/Brewfile new file mode 100644 index 0000000..577e34a --- /dev/null +++ b/Brewfile @@ -0,0 +1 @@ +brew "go" diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..b0a150f --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,144 @@ +# Changelog + +## 0.1.0-alpha.1 (2025-08-04) + +Full Changelog: [v0.0.1-alpha.0...v0.1.0-alpha.1](https://github.com/ScrapeGraphAI/scrapegraph-sdk/compare/v0.0.1-alpha.0...v0.1.0-alpha.1) + +### Features + +* add client integration ([cad027f](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/cad027fcdb9263f6386edd3a48837a87c884e2c0)) +* add docstring ([d486a7a](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/d486a7abac41dcf5af6ab1f15da63ed6bcfe4743)) +* add integration for env variables ([6a351f3](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/6a351f3ef70a1f00b5f5de5aaba2f408b6bf07dd)) +* add integration for local_scraper ([b9a17d5](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/b9a17d517413686732edf277a0ec978a5df00992)) +* add integration for sql ([2543b5a](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/2543b5a9b84826de5c583d38fe89cf21aad077e6)) +* add integration for the api ([c3ddf26](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/c3ddf269c70fbb794fd98bcca42c8ad74e96d70f)) +* add localScraper functionality 
([8701eb2](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/8701eb2ca7f108b922eb1617c850a58c0f88f8f9)) +* add markdownify and localscraper ([6296510](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/6296510b22ce511adde4265532ac6329a05967e0)) +* add markdownify functionality ([239d27a](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/239d27aac28c6b132aba54bbb1fa0216cc59ce89)) +* add optional headers to request ([bb851d7](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/bb851d785d121b039d5e968327fb930955a3fd92)) +* add requirement files ([06f84c6](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/06f84c6a31eaf1d19b10f7bf48a4f3dd5b44b4b1)) +* add scrapegraphai api integration ([4effb03](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/4effb03eb3ce515f35d15a1bf5e6683bd9afe16b)) +* add time varying timeout ([945b876](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/945b876a0c23d4b2a29ef916bd6fa9af425f9ab5)) +* added example of the smartScraper function using a schema ([baf933b](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/baf933b0826b63d4ecf61c8593676357619a1c73)) +* changed SyncClient to Client ([9e1e496](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/9e1e496059cd24810a96b818da1811830586f94b)) +* check ([9871ff8](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/9871ff81acfb42031ee9db526a7dba9e29d3c55b)) +* enhaced python sdk ([c253363](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/c2533636c230426be06cd505598e8a85d5771cbc)) +* final release maybe semantic? ([8ce3ccd](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/8ce3ccd3509d0487da212f541e039ee7009dd8f3)) +* fix ([d81ab09](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/d81ab091aa1ff08927ed7765055764b9e51083ee)) +* implemented support for requests with schema ([10a1a5a](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/10a1a5a477a6659aabf3afebfffdbefc14d12d3e)) +* maybe final release? 
([595c3c6](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/595c3c6b6ca0e8eaacd5959422ab9018516f3fa8)) +* merged localscraper into smartscraper ([503dbd1](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/503dbd19b8cec4d2ff4575786b0eec25db2e80e6)) +* modified icons ([bcb9b0b](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/bcb9b0b731b057d242fdf80b43d96879ff7a2764)) +* new release ([3d08e4e](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/3d08e4ea9c940a51bbe011107a6b7568dfcab54b)) +* refactoring of the folders ([908e67f](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/908e67fba5bbd351a17e9e535c1eb7b652958896)) +* refctoring of the folder ([ce02854](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/ce0285432ec62a7b1566c87408ede275710487ba)) +* removed local scraper ([9edfc67](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/9edfc6747d2604e52d7ee658e7cac1862cff89bf)) +* revert to old release ([d88a3ac](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/d88a3ac6969a0abdf1f6b8eccde9ad8284d41d20)) +* searchscraper ([2e04e5a](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/2e04e5a1bbd207a7ceeea594878bdea542a7a856)) +* semantic relaase ([30ff13a](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/30ff13a219df982e07df7b5366f09dedc0892de5)) +* semantic release ([6df4b18](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/6df4b1833c8c418766b1649f80f9d6cd1fa8a201)) +* semantic release ([edd23d9](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/edd23d93375ef33fa97a0b409045fdbd18090d10)) +* semantic release ([e5e4908](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/e5e49080bc6d3d1440d6b333f9cadfd493ff0449)) +* splitted files ([2791691](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/2791691a9381063cc38ac4f4fe7c884166c93116)) +* test ([3bb66c4](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/3bb66c4efe3eb5407f6eb88d31bda678ac3651b3)) +* test semantic release 
([63d3a36](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/63d3a3623363c358e5761e1b7737f262c8238c82)) +* test semantic release ([19eda59](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/19eda59be7adbea80ed189fd0af85ab0c3c930bd)) +* test semantic release ([3e611f2](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/3e611f21248a46120fa8ff3d30392522f6d1419a)) +* test semantic release ([6320819](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/6320819e12cbd3e0fa3faa93179d2d26f1323bb4)) +* try semantic release ([d953723](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/d9537230ef978aaf42d72073dc95ba598db8db6c)) +* update doc readme ([c02c411](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/c02c411ffba9fc7906fcc7664d0ce841e0e2fb54)) +* updated readmes ([bfdbea0](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/bfdbea038918d79df2e3e9442e25d5f08bbccbbc)) + + +### Bug Fixes + +* .toml file ([e719881](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/e7198817d8dac802361ab84bc4d5d961fb926767)) +* add enw timeout ([46ebd9d](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/46ebd9dc9897ca2ef9460a3e46b3a24abe90f943)) +* add new python compatibility ([77b67f6](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/77b67f646d75abd3a558b40cb31c52c12cc7182e)) +* add revert ([09257e0](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/09257e08246d8aee96b3944ac14cc14b88e5f818)) +* come back to py 3.10 ([26d3a75](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/26d3a75ed973590e21d55c985bf71f3905a3ac0e)) +* fixed configuration for ignored files ([bc08dcb](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/bc08dcb21536a146fd941119931bc8e89e8e42c6)) +* fixed schema example ([365378a](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/365378a0c8c9125800ed6d74629d87776cf484a0)) +* houses examples and typos 
([c596c44](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/c596c448e334a76444ecf3ee738ec275fd5316fa)) +* improve api desc ([62243f8](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/62243f84384ae238c0bd0c48abc76a6b99376c74)) +* logger working properly now ([9712d4c](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/9712d4c39eea860f813e86a5e2ffc14db6d3a655)) +* make timeout optional ([49b8e4b](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/49b8e4b8d3aa637bfd28a59e47cd1f5efad91075)) +* minor fix version ([0b972c6](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/0b972c69a9ea843d8ec89327f35c287b0d7a2bb4)) +* pyproject ([5d6a9ee](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/5d6a9eed262d1041eea3110fbaa1729f2c16855c)) +* pyproject ([2440f7f](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/2440f7f2a5179c6e3a86faf4eefa1d5edf7524c8)) +* pyproject.toml ([c6e6c6e](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/c6e6c6e33cd189bd78d7366dd570ee1e4d8c2c68)) +* pyproject.toml ([e8aed70](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/e8aed7011c1a65eca2909df88a804179a04bdd96)) +* python version ([24366b0](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/24366b08eefe0789da9a0ccafb8058e8744ee58b)) +* readme js sdk ([3c2178e](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/3c2178e04e873885abc8aca0312f5a4a1dd9cdd0)) +* removed wrong information ([88a2f50](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/88a2f509dc34ad69f41fe6d13f31de191895bc1a)) +* semanti release 2 ([97d9977](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/97d9977f2d757c6b23fa4c406433c817bf367bcb)) +* semantic release ([c0f3bbf](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/c0f3bbf5af127f5e5fced88bc39a86af4cb52a43)) +* sync client ([690e87b](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/690e87b52505f12da172147a78007497f6edf54c)) +* the "workspace" key has been removed because it was conflicting 
with the package.json file in the scrapegraph-js folder. ([1299173](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/129917377b6a685d769a480b717bf980d3199833)) +* timeout ([589aa49](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/589aa49d4434f7112a840d178e5e48918b7799e1)) +* updated comment ([8250818](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/825081883940bc1caa37f4f13e10f710770aeb9c)) +* updated env variable loading ([2643f11](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/2643f11c968f0daab26529d513f08c2817763b50)) +* updated hatchling version ([740933a](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/740933aff79a5873e6d1c633afcedb674d1f4cf0)) + + +### Chores + +* added dotenv pakage dependency ([2e9d93d](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/2e9d93d571c47c3b7aa789be811f53161387b08e)) +* added more information about the package ([97b8ff7](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/97b8ff749f7a588152629e246f690aad7ad348f1)) +* added Zod package dependency ([ee5738b](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/ee5738bd737cd07a553d148403a4bbb5e80e5be3)) +* changed pakage name ([9e9e138](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/9e9e138617658e068a1c77a4dbac24b4d550d42a)) +* fix _make_request not using it ([701a4c1](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/701a4c13bbe7e5d4ba9eae1846b0bd8abbbdb6b8)) +* fix pylint scripts ([5913d5f](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/5913d5f0d697196469f8ec952e1a65e1c7f49621)) +* fix pyproject version ([3567034](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/3567034e02e4dfab967248a5a4eaee426f145d6b)) +* fix semantic release, migrate to uv ([b6db205](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/b6db205ad5a90031bc658e65794e4dda2159fee2)) +* improved url validation ([83eac53](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/83eac530269a767e5469c4aded1656fe00a2cdc0)) +* refactor 
examples ([8e00846](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/8e008465f7280c53e2faab7a92f02871ffc5b867)) +* set up CI scripts ([f688bdc](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/f688bdc11746325582787fa3c1ffb429838f46b6)) +* set up eslint and prettier for code linting and formatting ([13cf1e5](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/13cf1e5c28ec739d2d35617bd57d7cf8203c3f7e)) +* sync repo ([b26b447](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/b26b447aaba0b4298d6bd319dc093a04d253704d)) +* **tests:** updated tests ([9149ce8](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/9149ce85a78b503098f80910c20de69831030378)) +* update SDK settings ([1e22194](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/1e2219430d800573547ea6a8d40ffa0e92e47d01)) +* update workflow scripts ([5ea9cac](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/5ea9cacb6758171283d96ff9aa1934c25af804f1)) + + +### Documentation + +* added an example of the smartScraper functionality using a schema ([cf2f28f](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/cf2f28fa029df0acb7058fde8239046d77ef0a8a)) +* added api reference ([6929a7a](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/6929a7adcc09f47a652cfd7ad7557314b52db9c0)) +* added api reference ([7b88876](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/7b88876facc2b37e4738797b6a18c65ca89f9aa0)) +* added cookbook reference ([e68c1bd](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/e68c1bd1268663a625441bc7f955a1d4514ac0ef)) +* added langchain-scrapegraph examples ([479dbdb](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/479dbdb1833a3ce6c2ce03eaf1400487ff534dd0)) +* added new image ([b052ddb](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/b052ddbe0d1a5ea182c54897c94d4c88fbc54ab8)) +* added open in colab badge ([c2fc1ef](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/c2fc1efc687623bd821468c19a102dbaed70bd4b)) +* added two 
langchain-scrapegraph examples ([8f3a87e](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/8f3a87e880f820f4453d564fec02ef02af3742b3)) +* added two new examples ([5fa2b42](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/5fa2b42685df565531cd7d2495e1d42e5c34ff90)) +* added wired langgraph react agent ([9f1e0cf](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/9f1e0cf72f4f84ee1f81439befaeace8c5c7ffa5)) +* added zillow example ([7fad92c](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/7fad92ca5e87cd9ecc60702e1599b2cff479af5c)) +* api reference ([855c2e5](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/855c2e51ebfaf7d8e4be008e8f22fdf66c0dc0e0)) +* **cookbook:** added two new examples ([f67769e](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/f67769e0ef0bba6fc4fd6908ec666b63ac2368b9)) +* fixed cookbook images and urls ([f860167](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/f8601674f686084a7df88e221475c014b40015b8)) +* github trending sdk ([320de37](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/320de37d2e8ec0d859ca91725c6cc35dab68e183)) +* improved examples ([a9c1fa5](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/a9c1fa5dcd7610b2b0c217d39fb2b77a67aa3fac)) +* improved main readme ([50fdf92](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/50fdf920e1d00e8f457138f9e68df74354696fc0)) +* link typo ([e1bfd6a](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/e1bfd6aa364b369c17457513f1c68e91376d0c68)) +* llama-index @VinciGit00 ([6de5eb2](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/6de5eb22490de2f5ff4075836bf1aca2e304ff8d)) +* research agent ([6e06afa](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/6e06afa9f8d5e9f05a38e605562ec10249216704)) +* updated new documentation urls ([1d0cb46](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/1d0cb46e5710707151ce227fa2043d5de5e92657)) +* updated precommit and installation guide 
([c16705b](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/c16705b8f405f57d2cb1719099d4b566186a7257)) +* updated readme ([ee9efa6](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/ee9efa608b9a284861f712ab2a69d49da3d26523)) + + +### Styles + +* Improve formatting and style ([671161d](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/671161d5e53cb87a84624ae6e99494f91f08f236)) +* Improve formatting and style ([5c43dc4](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/5c43dc4801cacfc3441f6235fe1e73a388d46b06)) + + +### Refactors + +* code refactoring ([a2b57c7](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/a2b57c7e482dfb5c7c1a125d1684e0367088c83b)) +* code refactoring ([164131a](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/164131a2abe899bd151113bd84efa113306327c2)) +* code refactoring ([01ca238](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/01ca2384f098ecbb063ac4681e6d32f590a03f42)) +* improved code structure ([aa6a483](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/aa6a48332aa782c1cd395e8ada005b4d2f9a1bde)) +* renamed functions ([d39f14e](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/d39f14e344ef59e3a8e4f501a080ccbe1151abee)) +* update readme ([0669f52](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/0669f5219970079bbe7bde7502b4f55e5c3f5a45)) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..16cefb4 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,66 @@ +## Setting up the environment + +To set up the repository, run: + +```sh +$ ./scripts/bootstrap +$ ./scripts/lint +``` + +This will install all the required dependencies and build the SDK. + +You can also [install go 1.18+ manually](https://go.dev/doc/install). + +## Modifying/Adding code + +Most of the SDK is generated code. Modifications to code will be persisted between generations, but may +result in merge conflicts between manual patches and changes from the generator. 
The generator will never +modify the contents of the `lib/` and `examples/` directories. + +## Adding and running examples + +All files in the `examples/` directory are not modified by the generator and can be freely edited or added to. + +```go +# add an example to examples//main.go + +package main + +func main() { + // ... +} +``` + +```sh +$ go run ./examples/ +``` + +## Using the repository from source + +To use a local version of this library from source in another project, edit the `go.mod` with a replace +directive. This can be done through the CLI with the following: + +```sh +$ go mod edit -replace github.com/ScrapeGraphAI/scrapegraph-sdk=/path/to/scrapegraph-sdk +``` + +## Running tests + +Most tests require you to [set up a mock server](https://github.com/stoplightio/prism) against the OpenAPI spec to run the tests. + +```sh +# you will need npm installed +$ npx prism mock path/to/your/openapi.yml +``` + +```sh +$ ./scripts/test +``` + +## Formatting + +This library uses the standard gofmt code formatter: + +```sh +$ ./scripts/format +``` diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..998f39a --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright 2025 Scrapegraphai SDK + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..8695eb9 --- /dev/null +++ b/README.md @@ -0,0 +1,474 @@ +# Scrapegraphai SDK Go API Library + +Go Reference + +The Scrapegraphai SDK Go library provides convenient access to the Scrapegraphai SDK REST API +from applications written in Go. + +It is generated with [Stainless](https://www.stainless.com/). + +## Installation + + + +```go +import ( + "github.com/ScrapeGraphAI/scrapegraph-sdk" // imported as scrapegraphaisdk +) +``` + + + +Or to pin the version: + + + +```sh +go get -u 'github.com/ScrapeGraphAI/scrapegraph-sdk@v0.1.0-alpha.1' +``` + + + +## Requirements + +This library requires Go 1.18+. + +## Usage + +The full API of this library can be found in [api.md](api.md). + +```go +package main + +import ( + "context" + + "github.com/ScrapeGraphAI/scrapegraph-sdk" + "github.com/ScrapeGraphAI/scrapegraph-sdk/option" +) + +func main() { + client := scrapegraphaisdk.NewClient( + option.WithAPIKey("My API Key"), // defaults to os.LookupEnv("SCRAPEGRAPHAI_SDK_API_KEY") + ) + err := client.Credits.List(context.TODO()) + if err != nil { + panic(err.Error()) + } +} + +``` + +### Request fields + +The scrapegraphaisdk library uses the [`omitzero`](https://tip.golang.org/doc/go1.24#encodingjsonpkgencodingjson) +semantics from the Go 1.24+ `encoding/json` release for request fields. 
+ +Required primitive fields (`int64`, `string`, etc.) feature the tag \`json:"...,required"\`. These +fields are always serialized, even their zero values. + +Optional primitive types are wrapped in a `param.Opt[T]`. These fields can be set with the provided constructors, `scrapegraphaisdk.String(string)`, `scrapegraphaisdk.Int(int64)`, etc. + +Any `param.Opt[T]`, map, slice, struct or string enum uses the +tag \`json:"...,omitzero"\`. Its zero value is considered omitted. + +The `param.IsOmitted(any)` function can confirm the presence of any `omitzero` field. + +```go +p := scrapegraphaisdk.ExampleParams{ + ID: "id_xxx", // required property + Name: scrapegraphaisdk.String("..."), // optional property + + Point: scrapegraphaisdk.Point{ + X: 0, // required field will serialize as 0 + Y: scrapegraphaisdk.Int(1), // optional field will serialize as 1 + // ... omitted non-required fields will not be serialized + }, + + Origin: scrapegraphaisdk.Origin{}, // the zero value of [Origin] is considered omitted +} +``` + +To send `null` instead of a `param.Opt[T]`, use `param.Null[T]()`. +To send `null` instead of a struct `T`, use `param.NullStruct[T]()`. + +```go +p.Name = param.Null[string]() // 'null' instead of string +p.Point = param.NullStruct[Point]() // 'null' instead of struct + +param.IsNull(p.Name) // true +param.IsNull(p.Point) // true +``` + +Request structs contain a `.SetExtraFields(map[string]any)` method which can send non-conforming +fields in the request body. Extra fields overwrite any struct fields with a matching +key. For security reasons, only use `SetExtraFields` with trusted data. + +To send a custom value instead of a struct, use `param.Override[T](value)`. 
+ +```go +// In cases where the API specifies a given type, +// but you want to send something else, use [SetExtraFields]: +p.SetExtraFields(map[string]any{ + "x": 0.01, // send "x" as a float instead of int +}) + +// Send a number instead of an object +custom := param.Override[scrapegraphaisdk.FooParams](12) +``` + +### Request unions + +Unions are represented as a struct with fields prefixed by "Of" for each of its variants; +only one field can be non-zero. The non-zero field will be serialized. + +Sub-properties of the union can be accessed via methods on the union struct. +These methods return a mutable pointer to the underlying data, if present. + +```go +// Only one field can be non-zero, use param.IsOmitted() to check if a field is set +type AnimalUnionParam struct { + OfCat *Cat `json:",omitzero,inline"` + OfDog *Dog `json:",omitzero,inline"` +} + +animal := AnimalUnionParam{ + OfCat: &Cat{ + Name: "Whiskers", + Owner: PersonParam{ + Address: AddressParam{Street: "3333 Coyote Hill Rd", ZipCode: 0}, + }, + }, +} + +// Mutating a field +if address := animal.GetOwner().GetAddress(); address != nil { + address.ZipCode = 94304 +} +``` + +### Response objects + +All fields in response structs are ordinary value types (not pointers or wrappers). +Response structs also include a special `JSON` field containing metadata about +each property. + +```go +type Animal struct { + Name string `json:"name,nullable"` + Owners int `json:"owners"` + Age int `json:"age"` + JSON struct { + Name respjson.Field + Owners respjson.Field + Age respjson.Field + ExtraFields map[string]respjson.Field + } `json:"-"` +} +``` + +To handle optional data, use the `.Valid()` method on the JSON field. +`.Valid()` returns false if a field is `null`, not present, or couldn't be unmarshaled. + +If `.Valid()` is false, the corresponding field will simply be its zero value.
+ +```go +raw := `{"owners": 1, "name": null}` + +var res Animal +json.Unmarshal([]byte(raw), &res) + +// Accessing regular fields + +res.Owners // 1 +res.Name // "" +res.Age // 0 + +// Optional field checks + +res.JSON.Owners.Valid() // true +res.JSON.Name.Valid() // false +res.JSON.Age.Valid() // false + +// Raw JSON values + +res.JSON.Owners.Raw() // "1" +res.JSON.Name.Raw() == "null" // true +res.JSON.Name.Raw() == respjson.Null // true +res.JSON.Age.Raw() == "" // true +res.JSON.Age.Raw() == respjson.Omitted // true +``` + +These `.JSON` structs also include an `ExtraFields` map containing +any properties in the json response that were not specified +in the struct. This can be useful for API features not yet +present in the SDK. + +```go +body := res.JSON.ExtraFields["my_unexpected_field"].Raw() +``` + +### Response Unions + +In responses, unions are represented by a flattened struct containing all possible fields from each of the +object variants. +To convert it to a variant use the `.AsFooVariant()` method or the `.AsAny()` method if present. + +If a response value union contains primitive values, primitive fields will be alongside +the properties but prefixed with `Of` and feature the tag `json:"...,inline"`. + +```go +type AnimalUnion struct { + // From variants [Dog], [Cat] + Owner Person `json:"owner"` + // From variant [Dog] + DogBreed string `json:"dog_breed"` + // From variant [Cat] + CatBreed string `json:"cat_breed"` + // ... + + JSON struct { + Owner respjson.Field + // ... + } `json:"-"` +} + +// If animal variant +if animal.Owner.Address.ZipCode == "" { + panic("missing zip code") +} + +// Switch on the variant +switch variant := animal.AsAny().(type) { +case Dog: +case Cat: +default: + panic("unexpected type") +} +``` + +### RequestOptions + +This library uses the functional options pattern. Functions defined in the +`option` package return a `RequestOption`, which is a closure that mutates a +`RequestConfig`. 
These options can be supplied to the client or at individual +requests. For example: + +```go +client := scrapegraphaisdk.NewClient( + // Adds a header to every request made by the client + option.WithHeader("X-Some-Header", "custom_header_info"), +) + +client.Credits.List(context.TODO(), ..., + // Override the header + option.WithHeader("X-Some-Header", "some_other_custom_header_info"), + // Add an undocumented field to the request body, using sjson syntax + option.WithJSONSet("some.json.path", map[string]string{"my": "object"}), +) +``` + +The request option `option.WithDebugLog(nil)` may be helpful while debugging. + +See the [full list of request options](https://pkg.go.dev/github.com/ScrapeGraphAI/scrapegraph-sdk/option). + +### Pagination + +This library provides some conveniences for working with paginated list endpoints. + +You can use `.ListAutoPaging()` methods to iterate through items across all pages: + +Or you can use simple `.List()` methods to fetch a single page and receive a standard response object +with additional helper methods like `.GetNextPage()`, e.g.: + +### Errors + +When the API returns a non-success status code, we return an error with type +`*scrapegraphaisdk.Error`. This contains the `StatusCode`, `*http.Request`, and +`*http.Response` values of the request, as well as the JSON of the error body +(much like other response objects in the SDK). + +To handle errors, we recommend that you use the `errors.As` pattern: + +```go +err := client.Credits.List(context.TODO()) +if err != nil { + var apierr *scrapegraphaisdk.Error + if errors.As(err, &apierr) { + println(string(apierr.DumpRequest(true))) // Prints the serialized HTTP request + println(string(apierr.DumpResponse(true))) // Prints the serialized HTTP response + } + panic(err.Error()) // GET "/v1/credits": 400 Bad Request { ... 
} +} +``` + +When other errors occur, they are returned unwrapped; for example, +if HTTP transport fails, you might receive `*url.Error` wrapping `*net.OpError`. + +### Timeouts + +Requests do not time out by default; use context to configure a timeout for a request lifecycle. + +Note that if a request is [retried](#retries), the context timeout does not start over. +To set a per-retry timeout, use `option.WithRequestTimeout()`. + +```go +// This sets the timeout for the request, including all the retries. +ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) +defer cancel() +client.Credits.List( + ctx, + // This sets the per-retry timeout + option.WithRequestTimeout(20*time.Second), +) +``` + +### File uploads + +Request parameters that correspond to file uploads in multipart requests are typed as +`io.Reader`. The contents of the `io.Reader` will by default be sent as a multipart form +part with the file name of "anonymous_file" and content-type of "application/octet-stream". + +The file name and content-type can be customized by implementing `Name() string` or `ContentType() +string` on the run-time type of `io.Reader`. Note that `os.File` implements `Name() string`, so a +file returned by `os.Open` will be sent with the file name on disk. + +We also provide a helper `scrapegraphaisdk.File(reader io.Reader, filename string, contentType string)` +which can be used to wrap any `io.Reader` with the appropriate file name and content type. + +### Retries + +Certain errors will be automatically retried 2 times by default, with a short exponential backoff. +We retry by default all connection errors, 408 Request Timeout, 409 Conflict, 429 Rate Limit, +and >=500 Internal errors. 
+ +You can use the `WithMaxRetries` option to configure or disable this: + +```go +// Configure the default for all requests: +client := scrapegraphaisdk.NewClient( + option.WithMaxRetries(0), // default is 2 +) + +// Override per-request: +client.Credits.List(context.TODO(), option.WithMaxRetries(5)) +``` + +### Accessing raw response data (e.g. response headers) + +You can access the raw HTTP response data by using the `option.WithResponseInto()` request option. This is useful when +you need to examine response headers, status codes, or other details. + +```go +// Create a variable to store the HTTP response +var response *http.Response +err := client.Credits.List(context.TODO(), option.WithResponseInto(&response)) +if err != nil { + // handle error +} + +fmt.Printf("Status Code: %d\n", response.StatusCode) +fmt.Printf("Headers: %+#v\n", response.Header) +``` + +### Making custom/undocumented requests + +This library is typed for convenient access to the documented API. If you need to access undocumented +endpoints, params, or response properties, the library can still be used. + +#### Undocumented endpoints + +To make requests to undocumented endpoints, you can use `client.Get`, `client.Post`, and other HTTP verbs. +`RequestOptions` on the client, such as retries, will be respected when making these requests. + +```go +var ( + // params can be an io.Reader, a []byte, an encoding/json serializable object, + // or a "…Params" struct defined in this library. + params map[string]any + + // result can be an []byte, *http.Response, an encoding/json deserializable object, + // or a model defined in this library. + result *http.Response +) +err := client.Post(context.Background(), "/unspecified", params, &result) +if err != nil { + … +} +``` + +#### Undocumented request params + +To make requests using undocumented parameters, you may use either the `option.WithQuerySet()` +or the `option.WithJSONSet()` methods.
+ +```go +params := FooNewParams{ + ID: "id_xxxx", + Data: FooNewParamsData{ + FirstName: scrapegraphaisdk.String("John"), + }, +} +client.Foo.New(context.Background(), params, option.WithJSONSet("data.last_name", "Doe")) +``` + +#### Undocumented response properties + +To access undocumented response properties, you may either access the raw JSON of the response as a string +with `result.JSON.RawJSON()`, or get the raw JSON of a particular field on the result with +`result.JSON.Foo.Raw()`. + +Any fields that are not present on the response struct will be saved and can be accessed by `result.JSON.ExtraFields()` which returns the extra fields as a `map[string]Field`. + +### Middleware + +We provide `option.WithMiddleware` which applies the given +middleware to requests. + +```go +func Logger(req *http.Request, next option.MiddlewareNext) (res *http.Response, err error) { + // Before the request + start := time.Now() + LogReq(req) + + // Forward the request to the next handler + res, err = next(req) + + // Handle stuff after the request + end := time.Now() + LogRes(res, err, start - end) + + return res, err +} + +client := scrapegraphaisdk.NewClient( + option.WithMiddleware(Logger), +) +``` + +When multiple middlewares are provided as variadic arguments, the middlewares +are applied left to right. If `option.WithMiddleware` is given +multiple times, for example first in the client then the method, the +middleware in the client will run first and the middleware given in the method +will run next. + +You may also replace the default `http.Client` with +`option.WithHTTPClient(client)`. Only one http client is +accepted (this overwrites any previous client) and receives requests after any +middleware has been applied. + +## Semantic versioning + +This package generally follows [SemVer](https://semver.org/spec/v2.0.0.html) conventions, though certain backwards-incompatible changes may be released as minor versions: + +1. 
Changes to library internals which are technically public but not intended or documented for external use. _(Please open a GitHub issue to let us know if you are relying on such internals.)_ +2. Changes that we do not expect to impact the vast majority of users in practice. + +We take backwards-compatibility seriously and work hard to ensure you can rely on a smooth upgrade experience. + +We are keen for your feedback; please open an [issue](https://www.github.com/ScrapeGraphAI/scrapegraph-sdk/issues) with questions, bugs, or suggestions. + +## Contributing + +See [the contributing documentation](./CONTRIBUTING.md). diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..83a5278 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,23 @@ +# Security Policy + +## Reporting Security Issues + +This SDK is generated by [Stainless Software Inc](http://stainless.com). Stainless takes security seriously, and encourages you to report any security vulnerability promptly so that appropriate action can be taken. + +To report a security issue, please contact the Stainless team at security@stainless.com. + +## Responsible Disclosure + +We appreciate the efforts of security researchers and individuals who help us maintain the security of +SDKs we generate. If you believe you have found a security vulnerability, please adhere to responsible +disclosure practices by allowing us a reasonable amount of time to investigate and address the issue +before making any information public. + +## Reporting Non-SDK Related Security Issues + +If you encounter security issues that are not directly related to SDKs but pertain to the services +or products provided by Scrapegraphai SDK, please follow the respective company's security reporting guidelines. + +--- + +Thank you for helping us keep the SDKs and systems they interact with secure. 
diff --git a/aliases.go b/aliases.go new file mode 100644 index 0000000..f2abf17 --- /dev/null +++ b/aliases.go @@ -0,0 +1,16 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package scrapegraphaisdk + +import ( + "github.com/ScrapeGraphAI/scrapegraph-sdk/internal/apierror" + "github.com/ScrapeGraphAI/scrapegraph-sdk/packages/param" +) + +// aliased to make [param.APIUnion] private when embedding +type paramUnion = param.APIUnion + +// aliased to make [param.APIObject] private when embedding +type paramObj = param.APIObject + +type Error = apierror.Error diff --git a/api.md b/api.md new file mode 100644 index 0000000..4de19c9 --- /dev/null +++ b/api.md @@ -0,0 +1,57 @@ +# Credits + +Methods: + +- client.Credits.List(ctx context.Context) error + +# Validate + +Methods: + +- client.Validate.Check(ctx context.Context) error + +# Feedback + +Methods: + +- client.Feedback.New(ctx context.Context) error + +# Smartscraper + +Methods: + +- client.Smartscraper.New(ctx context.Context) error +- client.Smartscraper.Get(ctx context.Context, requestID string) error + +# Searchscraper + +Methods: + +- client.Searchscraper.New(ctx context.Context) error +- client.Searchscraper.Get(ctx context.Context, requestID string) error + +# Markdownify + +Methods: + +- client.Markdownify.New(ctx context.Context) error +- client.Markdownify.Get(ctx context.Context, requestID string) error + +# GenerateSchema + +Methods: + +- client.GenerateSchema.Get(ctx context.Context, requestID string) error + +# Smartcrawler + +Methods: + +- client.Smartcrawler.New(ctx context.Context) error +- client.Smartcrawler.Get(ctx context.Context, sessionID string) error + +## Sessions + +Methods: + +- client.Smartcrawler.Sessions.List(ctx context.Context) error diff --git a/client.go b/client.go new file mode 100644 index 0000000..f7d4172 --- /dev/null +++ b/client.go @@ -0,0 +1,130 @@ +// File generated from our OpenAPI spec by Stainless. 
See CONTRIBUTING.md for details. + +package scrapegraphaisdk + +import ( + "context" + "net/http" + "os" + + "github.com/ScrapeGraphAI/scrapegraph-sdk/internal/requestconfig" + "github.com/ScrapeGraphAI/scrapegraph-sdk/option" +) + +// Client creates a struct with services and top level methods that help with +// interacting with the scrapegraphai-sdk API. You should not instantiate this +// client directly, and instead use the [NewClient] method instead. +type Client struct { + Options []option.RequestOption + Credits CreditService + Validate ValidateService + Feedback FeedbackService + Smartscraper SmartscraperService + Searchscraper SearchscraperService + Markdownify MarkdownifyService + GenerateSchema GenerateSchemaService + Smartcrawler SmartcrawlerService +} + +// DefaultClientOptions read from the environment (SCRAPEGRAPHAI_SDK_API_KEY, +// SCRAPEGRAPHAI_SDK_BASE_URL). This should be used to initialize new clients. +func DefaultClientOptions() []option.RequestOption { + defaults := []option.RequestOption{option.WithEnvironmentProduction()} + if o, ok := os.LookupEnv("SCRAPEGRAPHAI_SDK_BASE_URL"); ok { + defaults = append(defaults, option.WithBaseURL(o)) + } + if o, ok := os.LookupEnv("SCRAPEGRAPHAI_SDK_API_KEY"); ok { + defaults = append(defaults, option.WithAPIKey(o)) + } + return defaults +} + +// NewClient generates a new client with the default option read from the +// environment (SCRAPEGRAPHAI_SDK_API_KEY, SCRAPEGRAPHAI_SDK_BASE_URL). The option +// passed in as arguments are applied after these default arguments, and all option +// will be passed down to the services and requests that this client makes. +func NewClient(opts ...option.RequestOption) (r Client) { + opts = append(DefaultClientOptions(), opts...) + + r = Client{Options: opts} + + r.Credits = NewCreditService(opts...) + r.Validate = NewValidateService(opts...) + r.Feedback = NewFeedbackService(opts...) + r.Smartscraper = NewSmartscraperService(opts...) 
+ r.Searchscraper = NewSearchscraperService(opts...) + r.Markdownify = NewMarkdownifyService(opts...) + r.GenerateSchema = NewGenerateSchemaService(opts...) + r.Smartcrawler = NewSmartcrawlerService(opts...) + + return +} + +// Execute makes a request with the given context, method, URL, request params, +// response, and request options. This is useful for hitting undocumented endpoints +// while retaining the base URL, auth, retries, and other options from the client. +// +// If a byte slice or an [io.Reader] is supplied to params, it will be used as-is +// for the request body. +// +// The params is by default serialized into the body using [encoding/json]. If your +// type implements a MarshalJSON function, it will be used instead to serialize the +// request. If a URLQuery method is implemented, the returned [url.Values] will be +// used as query strings to the url. +// +// If your params struct uses [param.Field], you must provide either [MarshalJSON], +// [URLQuery], and/or [MarshalForm] functions. It is undefined behavior to use a +// struct uses [param.Field] without specifying how it is serialized. +// +// Any "…Params" object defined in this library can be used as the request +// argument. Note that 'path' arguments will not be forwarded into the url. +// +// The response body will be deserialized into the res variable, depending on its +// type: +// +// - A pointer to a [*http.Response] is populated by the raw response. +// - A pointer to a byte array will be populated with the contents of the request +// body. +// - A pointer to any other type uses this library's default JSON decoding, which +// respects UnmarshalJSON if it is defined on the type. +// - A nil value will not read the response body. +// +// For even greater flexibility, see [option.WithResponseInto] and +// [option.WithResponseBodyInto]. 
+func (r *Client) Execute(ctx context.Context, method string, path string, params any, res any, opts ...option.RequestOption) error { + opts = append(r.Options, opts...) + return requestconfig.ExecuteNewRequest(ctx, method, path, params, res, opts...) +} + +// Get makes a GET request with the given URL, params, and optionally deserializes +// to a response. See [Execute] documentation on the params and response. +func (r *Client) Get(ctx context.Context, path string, params any, res any, opts ...option.RequestOption) error { + return r.Execute(ctx, http.MethodGet, path, params, res, opts...) +} + +// Post makes a POST request with the given URL, params, and optionally +// deserializes to a response. See [Execute] documentation on the params and +// response. +func (r *Client) Post(ctx context.Context, path string, params any, res any, opts ...option.RequestOption) error { + return r.Execute(ctx, http.MethodPost, path, params, res, opts...) +} + +// Put makes a PUT request with the given URL, params, and optionally deserializes +// to a response. See [Execute] documentation on the params and response. +func (r *Client) Put(ctx context.Context, path string, params any, res any, opts ...option.RequestOption) error { + return r.Execute(ctx, http.MethodPut, path, params, res, opts...) +} + +// Patch makes a PATCH request with the given URL, params, and optionally +// deserializes to a response. See [Execute] documentation on the params and +// response. +func (r *Client) Patch(ctx context.Context, path string, params any, res any, opts ...option.RequestOption) error { + return r.Execute(ctx, http.MethodPatch, path, params, res, opts...) +} + +// Delete makes a DELETE request with the given URL, params, and optionally +// deserializes to a response. See [Execute] documentation on the params and +// response. 
+func (r *Client) Delete(ctx context.Context, path string, params any, res any, opts ...option.RequestOption) error { + return r.Execute(ctx, http.MethodDelete, path, params, res, opts...) +} diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..24cb4ea --- /dev/null +++ b/client_test.go @@ -0,0 +1,243 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package scrapegraphaisdk_test + +import ( + "context" + "fmt" + "net/http" + "reflect" + "testing" + "time" + + "github.com/ScrapeGraphAI/scrapegraph-sdk" + "github.com/ScrapeGraphAI/scrapegraph-sdk/internal" + "github.com/ScrapeGraphAI/scrapegraph-sdk/option" +) + +type closureTransport struct { + fn func(req *http.Request) (*http.Response, error) +} + +func (t *closureTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return t.fn(req) +} + +func TestUserAgentHeader(t *testing.T) { + var userAgent string + client := scrapegraphaisdk.NewClient( + option.WithAPIKey("My API Key"), + option.WithHTTPClient(&http.Client{ + Transport: &closureTransport{ + fn: func(req *http.Request) (*http.Response, error) { + userAgent = req.Header.Get("User-Agent") + return &http.Response{ + StatusCode: http.StatusOK, + }, nil + }, + }, + }), + ) + client.Credits.List(context.Background()) + if userAgent != fmt.Sprintf("ScrapegraphaiSDK/Go %s", internal.PackageVersion) { + t.Errorf("Expected User-Agent to be correct, but got: %#v", userAgent) + } +} + +func TestRetryAfter(t *testing.T) { + retryCountHeaders := make([]string, 0) + client := scrapegraphaisdk.NewClient( + option.WithAPIKey("My API Key"), + option.WithHTTPClient(&http.Client{ + Transport: &closureTransport{ + fn: func(req *http.Request) (*http.Response, error) { + retryCountHeaders = append(retryCountHeaders, req.Header.Get("X-Stainless-Retry-Count")) + return &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{ + http.CanonicalHeaderKey("Retry-After"): 
[]string{"0.1"}, + }, + }, nil + }, + }, + }), + ) + err := client.Credits.List(context.Background()) + if err == nil { + t.Error("Expected there to be a cancel error") + } + + attempts := len(retryCountHeaders) + if attempts != 3 { + t.Errorf("Expected %d attempts, got %d", 3, attempts) + } + + expectedRetryCountHeaders := []string{"0", "1", "2"} + if !reflect.DeepEqual(retryCountHeaders, expectedRetryCountHeaders) { + t.Errorf("Expected %v retry count headers, got %v", expectedRetryCountHeaders, retryCountHeaders) + } +} + +func TestDeleteRetryCountHeader(t *testing.T) { + retryCountHeaders := make([]string, 0) + client := scrapegraphaisdk.NewClient( + option.WithAPIKey("My API Key"), + option.WithHTTPClient(&http.Client{ + Transport: &closureTransport{ + fn: func(req *http.Request) (*http.Response, error) { + retryCountHeaders = append(retryCountHeaders, req.Header.Get("X-Stainless-Retry-Count")) + return &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{ + http.CanonicalHeaderKey("Retry-After"): []string{"0.1"}, + }, + }, nil + }, + }, + }), + option.WithHeaderDel("X-Stainless-Retry-Count"), + ) + err := client.Credits.List(context.Background()) + if err == nil { + t.Error("Expected there to be a cancel error") + } + + expectedRetryCountHeaders := []string{"", "", ""} + if !reflect.DeepEqual(retryCountHeaders, expectedRetryCountHeaders) { + t.Errorf("Expected %v retry count headers, got %v", expectedRetryCountHeaders, retryCountHeaders) + } +} + +func TestOverwriteRetryCountHeader(t *testing.T) { + retryCountHeaders := make([]string, 0) + client := scrapegraphaisdk.NewClient( + option.WithAPIKey("My API Key"), + option.WithHTTPClient(&http.Client{ + Transport: &closureTransport{ + fn: func(req *http.Request) (*http.Response, error) { + retryCountHeaders = append(retryCountHeaders, req.Header.Get("X-Stainless-Retry-Count")) + return &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{ + 
http.CanonicalHeaderKey("Retry-After"): []string{"0.1"}, + }, + }, nil + }, + }, + }), + option.WithHeader("X-Stainless-Retry-Count", "42"), + ) + err := client.Credits.List(context.Background()) + if err == nil { + t.Error("Expected there to be a cancel error") + } + + expectedRetryCountHeaders := []string{"42", "42", "42"} + if !reflect.DeepEqual(retryCountHeaders, expectedRetryCountHeaders) { + t.Errorf("Expected %v retry count headers, got %v", expectedRetryCountHeaders, retryCountHeaders) + } +} + +func TestRetryAfterMs(t *testing.T) { + attempts := 0 + client := scrapegraphaisdk.NewClient( + option.WithAPIKey("My API Key"), + option.WithHTTPClient(&http.Client{ + Transport: &closureTransport{ + fn: func(req *http.Request) (*http.Response, error) { + attempts++ + return &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{ + http.CanonicalHeaderKey("Retry-After-Ms"): []string{"100"}, + }, + }, nil + }, + }, + }), + ) + err := client.Credits.List(context.Background()) + if err == nil { + t.Error("Expected there to be a cancel error") + } + if want := 3; attempts != want { + t.Errorf("Expected %d attempts, got %d", want, attempts) + } +} + +func TestContextCancel(t *testing.T) { + client := scrapegraphaisdk.NewClient( + option.WithAPIKey("My API Key"), + option.WithHTTPClient(&http.Client{ + Transport: &closureTransport{ + fn: func(req *http.Request) (*http.Response, error) { + <-req.Context().Done() + return nil, req.Context().Err() + }, + }, + }), + ) + cancelCtx, cancel := context.WithCancel(context.Background()) + cancel() + err := client.Credits.List(cancelCtx) + if err == nil { + t.Error("Expected there to be a cancel error") + } +} + +func TestContextCancelDelay(t *testing.T) { + client := scrapegraphaisdk.NewClient( + option.WithAPIKey("My API Key"), + option.WithHTTPClient(&http.Client{ + Transport: &closureTransport{ + fn: func(req *http.Request) (*http.Response, error) { + <-req.Context().Done() + return nil, 
req.Context().Err() + }, + }, + }), + ) + cancelCtx, cancel := context.WithTimeout(context.Background(), 2*time.Millisecond) + defer cancel() + err := client.Credits.List(cancelCtx) + if err == nil { + t.Error("expected there to be a cancel error") + } +} + +func TestContextDeadline(t *testing.T) { + testTimeout := time.After(3 * time.Second) + testDone := make(chan struct{}) + + deadline := time.Now().Add(100 * time.Millisecond) + deadlineCtx, cancel := context.WithDeadline(context.Background(), deadline) + defer cancel() + + go func() { + client := scrapegraphaisdk.NewClient( + option.WithAPIKey("My API Key"), + option.WithHTTPClient(&http.Client{ + Transport: &closureTransport{ + fn: func(req *http.Request) (*http.Response, error) { + <-req.Context().Done() + return nil, req.Context().Err() + }, + }, + }), + ) + err := client.Credits.List(deadlineCtx) + if err == nil { + t.Error("expected there to be a deadline error") + } + close(testDone) + }() + + select { + case <-testTimeout: + t.Fatal("client didn't finish in time") + case <-testDone: + if diff := time.Since(deadline); diff < -30*time.Millisecond || 30*time.Millisecond < diff { + t.Fatalf("client did not return within 30ms of context deadline, got %s", diff) + } + } +} diff --git a/credit.go b/credit.go new file mode 100644 index 0000000..7699946 --- /dev/null +++ b/credit.go @@ -0,0 +1,39 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package scrapegraphaisdk + +import ( + "context" + "net/http" + + "github.com/ScrapeGraphAI/scrapegraph-sdk/internal/requestconfig" + "github.com/ScrapeGraphAI/scrapegraph-sdk/option" +) + +// CreditService contains methods and other services that help with interacting +// with the scrapegraphai-sdk API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewCreditService] method instead. 
+type CreditService struct { + Options []option.RequestOption +} + +// NewCreditService generates a new service that applies the given options to each +// request. These options are applied after the parent client's options (if there +// is one), and before any request-specific options. +func NewCreditService(opts ...option.RequestOption) (r CreditService) { + r = CreditService{} + r.Options = opts + return +} + +// GET /credits +func (r *CreditService) List(ctx context.Context, opts ...option.RequestOption) (err error) { + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("Accept", "")}, opts...) + path := "v1/credits" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, nil, opts...) + return +} diff --git a/credit_test.go b/credit_test.go new file mode 100644 index 0000000..63a43fd --- /dev/null +++ b/credit_test.go @@ -0,0 +1,37 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package scrapegraphaisdk_test + +import ( + "context" + "errors" + "os" + "testing" + + "github.com/ScrapeGraphAI/scrapegraph-sdk" + "github.com/ScrapeGraphAI/scrapegraph-sdk/internal/testutil" + "github.com/ScrapeGraphAI/scrapegraph-sdk/option" +) + +func TestCreditList(t *testing.T) { + t.Skip("skipped: tests are disabled for the time being") + baseURL := "http://localhost:4010" + if envURL, ok := os.LookupEnv("TEST_API_BASE_URL"); ok { + baseURL = envURL + } + if !testutil.CheckTestServer(t, baseURL) { + return + } + client := scrapegraphaisdk.NewClient( + option.WithBaseURL(baseURL), + option.WithAPIKey("My API Key"), + ) + err := client.Credits.List(context.TODO()) + if err != nil { + var apierr *scrapegraphaisdk.Error + if errors.As(err, &apierr) { + t.Log(string(apierr.DumpRequest(true))) + } + t.Fatalf("err should be nil: %s", err.Error()) + } +} diff --git a/examples/.env.examples b/examples/.env.examples deleted file mode 100644 index a3ba0c3..0000000 --- 
a/examples/.env.examples +++ /dev/null @@ -1 +0,0 @@ -SCRAPEGRAPH_API_KEY="YOUR_SCRAPEGRAPH_API_KEY_HERE" \ No newline at end of file diff --git a/examples/.keep b/examples/.keep new file mode 100644 index 0000000..d8c73e9 --- /dev/null +++ b/examples/.keep @@ -0,0 +1,4 @@ +File generated from our OpenAPI spec by Stainless. + +This directory can be used to store example files demonstrating usage of this SDK. +It is ignored by Stainless code generation and its content (other than this keep file) won't be touched. \ No newline at end of file diff --git a/examples/scrape_example.py b/examples/scrape_example.py deleted file mode 100644 index c398ccb..0000000 --- a/examples/scrape_example.py +++ /dev/null @@ -1,39 +0,0 @@ -import os -from dotenv import load_dotenv -from scrapegraphaiapisdk.scrape import scrape -from pydantic import BaseModel -from typing import List - -# Load environment variables from .env file -load_dotenv() - -class Product(BaseModel): - name: str - price: float - description: str - -class ProductList(BaseModel): - products: List[Product] - -def main(): - # Get API key from environment variables - api_key = os.getenv("SCRAPEGRAPH_API_KEY") - - # URL to scrape - url = "https://example.com/products" - - # Natural language prompt - prompt = "Extract all products from this page including their names, prices, and descriptions" - - # Create schema - schema = ProductList - - # Make the request - try: - result = scrape(api_key, url, prompt, schema) - print(f"Scraped data: {result}") - except Exception as e: - print(f"Error occurred: {e}") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/examples/status_example.py b/examples/status_example.py deleted file mode 100644 index 8239599..0000000 --- a/examples/status_example.py +++ /dev/null @@ -1,20 +0,0 @@ -import os -from dotenv import load_dotenv -from scrapegraphaiapisdk.status import status - -# Load environment variables from .env file -load_dotenv() - -def main(): - # Get API 
key from environment variables - api_key = os.getenv("SCRAPEGRAPH_API_KEY") - - # Check API status - try: - result = status(api_key) - print(f"API Status: {result}") - except Exception as e: - print(f"Error occurred: {e}") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/feedback.go b/feedback.go new file mode 100644 index 0000000..344e784 --- /dev/null +++ b/feedback.go @@ -0,0 +1,39 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package scrapegraphaisdk + +import ( + "context" + "net/http" + + "github.com/ScrapeGraphAI/scrapegraph-sdk/internal/requestconfig" + "github.com/ScrapeGraphAI/scrapegraph-sdk/option" +) + +// FeedbackService contains methods and other services that help with interacting +// with the scrapegraphai-sdk API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewFeedbackService] method instead. +type FeedbackService struct { + Options []option.RequestOption +} + +// NewFeedbackService generates a new service that applies the given options to +// each request. These options are applied after the parent client's options (if +// there is one), and before any request-specific options. +func NewFeedbackService(opts ...option.RequestOption) (r FeedbackService) { + r = FeedbackService{} + r.Options = opts + return +} + +// POST /feedback +func (r *FeedbackService) New(ctx context.Context, opts ...option.RequestOption) (err error) { + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithHeader("Accept", "")}, opts...) + path := "v1/feedback" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, nil, nil, opts...) 
+ return +} diff --git a/feedback_test.go b/feedback_test.go new file mode 100644 index 0000000..2ddcb8d --- /dev/null +++ b/feedback_test.go @@ -0,0 +1,37 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package scrapegraphaisdk_test + +import ( + "context" + "errors" + "os" + "testing" + + "github.com/ScrapeGraphAI/scrapegraph-sdk" + "github.com/ScrapeGraphAI/scrapegraph-sdk/internal/testutil" + "github.com/ScrapeGraphAI/scrapegraph-sdk/option" +) + +func TestFeedbackNew(t *testing.T) { + t.Skip("skipped: tests are disabled for the time being") + baseURL := "http://localhost:4010" + if envURL, ok := os.LookupEnv("TEST_API_BASE_URL"); ok { + baseURL = envURL + } + if !testutil.CheckTestServer(t, baseURL) { + return + } + client := scrapegraphaisdk.NewClient( + option.WithBaseURL(baseURL), + option.WithAPIKey("My API Key"), + ) + err := client.Feedback.New(context.TODO()) + if err != nil { + var apierr *scrapegraphaisdk.Error + if errors.As(err, &apierr) { + t.Log(string(apierr.DumpRequest(true))) + } + t.Fatalf("err should be nil: %s", err.Error()) + } +} diff --git a/field.go b/field.go new file mode 100644 index 0000000..6a2e2b8 --- /dev/null +++ b/field.go @@ -0,0 +1,45 @@ +package scrapegraphaisdk + +import ( + "github.com/ScrapeGraphAI/scrapegraph-sdk/packages/param" + "io" + "time" +) + +func String(s string) param.Opt[string] { return param.NewOpt(s) } +func Int(i int64) param.Opt[int64] { return param.NewOpt(i) } +func Bool(b bool) param.Opt[bool] { return param.NewOpt(b) } +func Float(f float64) param.Opt[float64] { return param.NewOpt(f) } +func Time(t time.Time) param.Opt[time.Time] { return param.NewOpt(t) } + +func Opt[T comparable](v T) param.Opt[T] { return param.NewOpt(v) } +func Ptr[T any](v T) *T { return &v } + +func IntPtr(v int64) *int64 { return &v } +func BoolPtr(v bool) *bool { return &v } +func FloatPtr(v float64) *float64 { return &v } +func StringPtr(v string) *string { return &v } +func 
TimePtr(v time.Time) *time.Time { return &v } + +func File(rdr io.Reader, filename string, contentType string) file { + return file{rdr, filename, contentType} +} + +type file struct { + io.Reader + name string + contentType string +} + +func (f file) Filename() string { + if f.name != "" { + return f.name + } else if named, ok := f.Reader.(interface{ Name() string }); ok { + return named.Name() + } + return "" +} + +func (f file) ContentType() string { + return f.contentType +} diff --git a/generateschema.go b/generateschema.go new file mode 100644 index 0000000..fe89498 --- /dev/null +++ b/generateschema.go @@ -0,0 +1,45 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package scrapegraphaisdk + +import ( + "context" + "errors" + "fmt" + "net/http" + + "github.com/ScrapeGraphAI/scrapegraph-sdk/internal/requestconfig" + "github.com/ScrapeGraphAI/scrapegraph-sdk/option" +) + +// GenerateSchemaService contains methods and other services that help with +// interacting with the scrapegraphai-sdk API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewGenerateSchemaService] method instead. +type GenerateSchemaService struct { + Options []option.RequestOption +} + +// NewGenerateSchemaService generates a new service that applies the given options +// to each request. These options are applied after the parent client's options (if +// there is one), and before any request-specific options. +func NewGenerateSchemaService(opts ...option.RequestOption) (r GenerateSchemaService) { + r = GenerateSchemaService{} + r.Options = opts + return +} + +// GET /generate_schema/{request_id} +func (r *GenerateSchemaService) Get(ctx context.Context, requestID string, opts ...option.RequestOption) (err error) { + opts = append(r.Options[:], opts...) 
+ opts = append([]option.RequestOption{option.WithHeader("Accept", "")}, opts...) + if requestID == "" { + err = errors.New("missing required request_id parameter") + return + } + path := fmt.Sprintf("generate_schema/%s", requestID) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, nil, opts...) + return +} diff --git a/generateschema_test.go b/generateschema_test.go new file mode 100644 index 0000000..ac193b8 --- /dev/null +++ b/generateschema_test.go @@ -0,0 +1,37 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package scrapegraphaisdk_test + +import ( + "context" + "errors" + "os" + "testing" + + "github.com/ScrapeGraphAI/scrapegraph-sdk" + "github.com/ScrapeGraphAI/scrapegraph-sdk/internal/testutil" + "github.com/ScrapeGraphAI/scrapegraph-sdk/option" +) + +func TestGenerateSchemaGet(t *testing.T) { + t.Skip("skipped: tests are disabled for the time being") + baseURL := "http://localhost:4010" + if envURL, ok := os.LookupEnv("TEST_API_BASE_URL"); ok { + baseURL = envURL + } + if !testutil.CheckTestServer(t, baseURL) { + return + } + client := scrapegraphaisdk.NewClient( + option.WithBaseURL(baseURL), + option.WithAPIKey("My API Key"), + ) + err := client.GenerateSchema.Get(context.TODO(), "request_id") + if err != nil { + var apierr *scrapegraphaisdk.Error + if errors.As(err, &apierr) { + t.Log(string(apierr.DumpRequest(true))) + } + t.Fatalf("err should be nil: %s", err.Error()) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..bba510d --- /dev/null +++ b/go.mod @@ -0,0 +1,13 @@ +module github.com/ScrapeGraphAI/scrapegraph-sdk + +go 1.21 + +require ( + github.com/tidwall/gjson v1.14.4 + github.com/tidwall/sjson v1.2.5 +) + +require ( + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..a70a5e0 --- /dev/null +++ b/go.sum @@ -0,0 +1,10 @@ +github.com/tidwall/gjson 
v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= +github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= diff --git a/internal/apierror/apierror.go b/internal/apierror/apierror.go new file mode 100644 index 0000000..e52e475 --- /dev/null +++ b/internal/apierror/apierror.go @@ -0,0 +1,50 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package apierror + +import ( + "fmt" + "net/http" + "net/http/httputil" + + "github.com/ScrapeGraphAI/scrapegraph-sdk/internal/apijson" + "github.com/ScrapeGraphAI/scrapegraph-sdk/packages/respjson" +) + +// Error represents an error that originates from the API, i.e. when a request is +// made and the API returns a response with a HTTP status code. Other errors are +// not wrapped by this SDK. +type Error struct { + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. 
+ JSON struct { + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` + StatusCode int + Request *http.Request + Response *http.Response +} + +// Returns the unmodified JSON received from the API +func (r Error) RawJSON() string { return r.JSON.raw } +func (r *Error) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +func (r *Error) Error() string { + // Attempt to re-populate the response body + return fmt.Sprintf("%s %q: %d %s %s", r.Request.Method, r.Request.URL, r.Response.StatusCode, http.StatusText(r.Response.StatusCode), r.JSON.raw) +} + +func (r *Error) DumpRequest(body bool) []byte { + if r.Request.GetBody != nil { + r.Request.Body, _ = r.Request.GetBody() + } + out, _ := httputil.DumpRequestOut(r.Request, body) + return out +} + +func (r *Error) DumpResponse(body bool) []byte { + out, _ := httputil.DumpResponse(r.Response, body) + return out +} diff --git a/internal/apiform/encoder.go b/internal/apiform/encoder.go new file mode 100644 index 0000000..5f266b1 --- /dev/null +++ b/internal/apiform/encoder.go @@ -0,0 +1,465 @@ +package apiform + +import ( + "fmt" + "io" + "mime/multipart" + "net/textproto" + "path" + "reflect" + "sort" + "strconv" + "strings" + "sync" + "time" + + "github.com/ScrapeGraphAI/scrapegraph-sdk/packages/param" +) + +var encoders sync.Map // map[encoderEntry]encoderFunc + +func Marshal(value any, writer *multipart.Writer) error { + e := &encoder{ + dateFormat: time.RFC3339, + arrayFmt: "comma", + } + return e.marshal(value, writer) +} + +func MarshalRoot(value any, writer *multipart.Writer) error { + e := &encoder{ + root: true, + dateFormat: time.RFC3339, + arrayFmt: "comma", + } + return e.marshal(value, writer) +} + +func MarshalWithSettings(value any, writer *multipart.Writer, arrayFormat string) error { + e := &encoder{ + arrayFmt: arrayFormat, + dateFormat: time.RFC3339, + } + return e.marshal(value, writer) +} + +type encoder struct { + arrayFmt string + dateFormat string + root bool 
+} + +type encoderFunc func(key string, value reflect.Value, writer *multipart.Writer) error + +type encoderField struct { + tag parsedStructTag + fn encoderFunc + idx []int +} + +type encoderEntry struct { + reflect.Type + dateFormat string + root bool +} + +func (e *encoder) marshal(value any, writer *multipart.Writer) error { + val := reflect.ValueOf(value) + if !val.IsValid() { + return nil + } + typ := val.Type() + enc := e.typeEncoder(typ) + return enc("", val, writer) +} + +func (e *encoder) typeEncoder(t reflect.Type) encoderFunc { + entry := encoderEntry{ + Type: t, + dateFormat: e.dateFormat, + root: e.root, + } + + if fi, ok := encoders.Load(entry); ok { + return fi.(encoderFunc) + } + + // To deal with recursive types, populate the map with an + // indirect func before we build it. This type waits on the + // real func (f) to be ready and then calls it. This indirect + // func is only used for recursive types. + var ( + wg sync.WaitGroup + f encoderFunc + ) + wg.Add(1) + fi, loaded := encoders.LoadOrStore(entry, encoderFunc(func(key string, v reflect.Value, writer *multipart.Writer) error { + wg.Wait() + return f(key, v, writer) + })) + if loaded { + return fi.(encoderFunc) + } + + // Compute the real encoder and replace the indirect func with it. 
+ f = e.newTypeEncoder(t) + wg.Done() + encoders.Store(entry, f) + return f +} + +func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc { + if t.ConvertibleTo(reflect.TypeOf(time.Time{})) { + return e.newTimeTypeEncoder() + } + if t.Implements(reflect.TypeOf((*io.Reader)(nil)).Elem()) { + return e.newReaderTypeEncoder() + } + e.root = false + switch t.Kind() { + case reflect.Pointer: + inner := t.Elem() + + innerEncoder := e.typeEncoder(inner) + return func(key string, v reflect.Value, writer *multipart.Writer) error { + if !v.IsValid() || v.IsNil() { + return nil + } + return innerEncoder(key, v.Elem(), writer) + } + case reflect.Struct: + return e.newStructTypeEncoder(t) + case reflect.Slice, reflect.Array: + return e.newArrayTypeEncoder(t) + case reflect.Map: + return e.newMapEncoder(t) + case reflect.Interface: + return e.newInterfaceEncoder() + default: + return e.newPrimitiveTypeEncoder(t) + } +} + +func (e *encoder) newPrimitiveTypeEncoder(t reflect.Type) encoderFunc { + switch t.Kind() { + // Note that we could use `gjson` to encode these types but it would complicate our + // code more and this current code shouldn't cause any issues + case reflect.String: + return func(key string, v reflect.Value, writer *multipart.Writer) error { + return writer.WriteField(key, v.String()) + } + case reflect.Bool: + return func(key string, v reflect.Value, writer *multipart.Writer) error { + if v.Bool() { + return writer.WriteField(key, "true") + } + return writer.WriteField(key, "false") + } + case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: + return func(key string, v reflect.Value, writer *multipart.Writer) error { + return writer.WriteField(key, strconv.FormatInt(v.Int(), 10)) + } + case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return func(key string, v reflect.Value, writer *multipart.Writer) error { + return writer.WriteField(key, strconv.FormatUint(v.Uint(), 10)) + } + case reflect.Float32: + return func(key string, v 
reflect.Value, writer *multipart.Writer) error { + return writer.WriteField(key, strconv.FormatFloat(v.Float(), 'f', -1, 32)) + } + case reflect.Float64: + return func(key string, v reflect.Value, writer *multipart.Writer) error { + return writer.WriteField(key, strconv.FormatFloat(v.Float(), 'f', -1, 64)) + } + default: + return func(key string, v reflect.Value, writer *multipart.Writer) error { + return fmt.Errorf("unknown type received at primitive encoder: %s", t.String()) + } + } +} + +func arrayKeyEncoder(arrayFmt string) func(string, int) string { + var keyFn func(string, int) string + switch arrayFmt { + case "comma", "repeat": + keyFn = func(k string, _ int) string { return k } + case "brackets": + keyFn = func(key string, _ int) string { return key + "[]" } + case "indices:dots": + keyFn = func(k string, i int) string { + if k == "" { + return strconv.Itoa(i) + } + return k + "." + strconv.Itoa(i) + } + case "indices:brackets": + keyFn = func(k string, i int) string { + if k == "" { + return strconv.Itoa(i) + } + return k + "[" + strconv.Itoa(i) + "]" + } + } + return keyFn +} + +func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc { + itemEncoder := e.typeEncoder(t.Elem()) + keyFn := arrayKeyEncoder(e.arrayFmt) + return func(key string, v reflect.Value, writer *multipart.Writer) error { + if keyFn == nil { + return fmt.Errorf("apiform: unsupported array format") + } + for i := 0; i < v.Len(); i++ { + err := itemEncoder(keyFn(key, i), v.Index(i), writer) + if err != nil { + return err + } + } + return nil + } +} + +func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc { + if t.Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) { + return e.newRichFieldTypeEncoder(t) + } + + for i := 0; i < t.NumField(); i++ { + if t.Field(i).Type == paramUnionType && t.Field(i).Anonymous { + return e.newStructUnionTypeEncoder(t) + } + } + + encoderFields := []encoderField{} + extraEncoder := (*encoderField)(nil) + + // This helper allows 
us to recursively collect field encoders into a flat + // array. The parameter `index` keeps track of the access patterns necessary + // to get to some field. + var collectEncoderFields func(r reflect.Type, index []int) + collectEncoderFields = func(r reflect.Type, index []int) { + for i := 0; i < r.NumField(); i++ { + idx := append(index, i) + field := t.FieldByIndex(idx) + if !field.IsExported() { + continue + } + // If this is an embedded struct, traverse one level deeper to extract + // the field and get their encoders as well. + if field.Anonymous { + collectEncoderFields(field.Type, idx) + continue + } + // If json tag is not present, then we skip, which is intentionally + // different behavior from the stdlib. + ptag, ok := parseFormStructTag(field) + if !ok { + continue + } + // We only want to support unexported field if they're tagged with + // `extras` because that field shouldn't be part of the public API. We + // also want to only keep the top level extras + if ptag.extras && len(index) == 0 { + extraEncoder = &encoderField{ptag, e.typeEncoder(field.Type.Elem()), idx} + continue + } + if ptag.name == "-" || ptag.name == "" { + continue + } + + dateFormat, ok := parseFormatStructTag(field) + oldFormat := e.dateFormat + if ok { + switch dateFormat { + case "date-time": + e.dateFormat = time.RFC3339 + case "date": + e.dateFormat = "2006-01-02" + } + } + + var encoderFn encoderFunc + if ptag.omitzero { + typeEncoderFn := e.typeEncoder(field.Type) + encoderFn = func(key string, value reflect.Value, writer *multipart.Writer) error { + if value.IsZero() { + return nil + } + return typeEncoderFn(key, value, writer) + } + } else { + encoderFn = e.typeEncoder(field.Type) + } + encoderFields = append(encoderFields, encoderField{ptag, encoderFn, idx}) + e.dateFormat = oldFormat + } + } + collectEncoderFields(t, []int{}) + + // Ensure deterministic output by sorting by lexicographic order + sort.Slice(encoderFields, func(i, j int) bool { + return 
encoderFields[i].tag.name < encoderFields[j].tag.name + }) + + return func(key string, value reflect.Value, writer *multipart.Writer) error { + if key != "" { + key = key + "." + } + + for _, ef := range encoderFields { + field := value.FieldByIndex(ef.idx) + err := ef.fn(key+ef.tag.name, field, writer) + if err != nil { + return err + } + } + + if extraEncoder != nil { + err := e.encodeMapEntries(key, value.FieldByIndex(extraEncoder.idx), writer) + if err != nil { + return err + } + } + + return nil + } +} + +var paramUnionType = reflect.TypeOf((*param.APIUnion)(nil)).Elem() + +func (e *encoder) newStructUnionTypeEncoder(t reflect.Type) encoderFunc { + var fieldEncoders []encoderFunc + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if field.Type == paramUnionType && field.Anonymous { + fieldEncoders = append(fieldEncoders, nil) + continue + } + fieldEncoders = append(fieldEncoders, e.typeEncoder(field.Type)) + } + + return func(key string, value reflect.Value, writer *multipart.Writer) error { + for i := 0; i < t.NumField(); i++ { + if value.Field(i).Type() == paramUnionType { + continue + } + if !value.Field(i).IsZero() { + return fieldEncoders[i](key, value.Field(i), writer) + } + } + return fmt.Errorf("apiform: union %s has no field set", t.String()) + } +} + +func (e *encoder) newTimeTypeEncoder() encoderFunc { + format := e.dateFormat + return func(key string, value reflect.Value, writer *multipart.Writer) error { + return writer.WriteField(key, value.Convert(reflect.TypeOf(time.Time{})).Interface().(time.Time).Format(format)) + } +} + +func (e encoder) newInterfaceEncoder() encoderFunc { + return func(key string, value reflect.Value, writer *multipart.Writer) error { + value = value.Elem() + if !value.IsValid() { + return nil + } + return e.typeEncoder(value.Type())(key, value, writer) + } +} + +var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") + +func escapeQuotes(s string) string { + return quoteEscaper.Replace(s) +} + +func (e 
*encoder) newReaderTypeEncoder() encoderFunc { + return func(key string, value reflect.Value, writer *multipart.Writer) error { + reader, ok := value.Convert(reflect.TypeOf((*io.Reader)(nil)).Elem()).Interface().(io.Reader) + if !ok { + return nil + } + filename := "anonymous_file" + contentType := "application/octet-stream" + if named, ok := reader.(interface{ Filename() string }); ok { + filename = named.Filename() + } else if named, ok := reader.(interface{ Name() string }); ok { + filename = path.Base(named.Name()) + } + if typed, ok := reader.(interface{ ContentType() string }); ok { + contentType = typed.ContentType() + } + + // Below is taken almost 1-for-1 from [multipart.CreateFormFile] + h := make(textproto.MIMEHeader) + h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, escapeQuotes(key), escapeQuotes(filename))) + h.Set("Content-Type", contentType) + filewriter, err := writer.CreatePart(h) + if err != nil { + return err + } + _, err = io.Copy(filewriter, reader) + return err + } +} + +// Given a []byte of json (may either be an empty object or an object that already contains entries) +// encode all of the entries in the map to the json byte array. +func (e *encoder) encodeMapEntries(key string, v reflect.Value, writer *multipart.Writer) error { + type mapPair struct { + key string + value reflect.Value + } + + if key != "" { + key = key + "." 
+ } + + pairs := []mapPair{} + + iter := v.MapRange() + for iter.Next() { + if iter.Key().Type().Kind() == reflect.String { + pairs = append(pairs, mapPair{key: iter.Key().String(), value: iter.Value()}) + } else { + return fmt.Errorf("cannot encode a map with a non string key") + } + } + + // Ensure deterministic output + sort.Slice(pairs, func(i, j int) bool { + return pairs[i].key < pairs[j].key + }) + + elementEncoder := e.typeEncoder(v.Type().Elem()) + for _, p := range pairs { + err := elementEncoder(key+string(p.key), p.value, writer) + if err != nil { + return err + } + } + + return nil +} + +func (e *encoder) newMapEncoder(_ reflect.Type) encoderFunc { + return func(key string, value reflect.Value, writer *multipart.Writer) error { + return e.encodeMapEntries(key, value, writer) + } +} + +func WriteExtras(writer *multipart.Writer, extras map[string]any) (err error) { + for k, v := range extras { + str, ok := v.(string) + if !ok { + break + } + err = writer.WriteField(k, str) + if err != nil { + break + } + } + return +} diff --git a/internal/apiform/form.go b/internal/apiform/form.go new file mode 100644 index 0000000..5445116 --- /dev/null +++ b/internal/apiform/form.go @@ -0,0 +1,5 @@ +package apiform + +type Marshaler interface { + MarshalMultipart() ([]byte, string, error) +} diff --git a/internal/apiform/form_test.go b/internal/apiform/form_test.go new file mode 100644 index 0000000..fe83177 --- /dev/null +++ b/internal/apiform/form_test.go @@ -0,0 +1,560 @@ +package apiform + +import ( + "bytes" + "github.com/ScrapeGraphAI/scrapegraph-sdk/packages/param" + "io" + "mime/multipart" + "strings" + "testing" + "time" +) + +func P[T any](v T) *T { return &v } + +type Primitives struct { + A bool `form:"a"` + B int `form:"b"` + C uint `form:"c"` + D float64 `form:"d"` + E float32 `form:"e"` + F []int `form:"f"` +} + +// These aliases are necessary to bypass the cache. +// This only relevant during testing. 
+type int_ int +type PrimitivesBrackets struct { + F []int_ `form:"f"` +} + +type PrimitivePointers struct { + A *bool `form:"a"` + B *int `form:"b"` + C *uint `form:"c"` + D *float64 `form:"d"` + E *float32 `form:"e"` + F *[]int `form:"f"` +} + +type Slices struct { + Slice []Primitives `form:"slices"` +} + +type DateTime struct { + Date time.Time `form:"date" format:"date"` + DateTime time.Time `form:"date-time" format:"date-time"` +} + +type AdditionalProperties struct { + A bool `form:"a"` + Extras map[string]any `form:"-,extras"` +} + +type TypedAdditionalProperties struct { + A bool `form:"a"` + Extras map[string]int `form:"-,extras"` +} + +type EmbeddedStructs struct { + AdditionalProperties + A *int `form:"number2"` + Extras map[string]any `form:"-,extras"` +} + +type Recursive struct { + Name string `form:"name"` + Child *Recursive `form:"child"` +} + +type UnknownStruct struct { + Unknown any `form:"unknown"` +} + +type UnionStruct struct { + Union Union `form:"union" format:"date"` +} + +type Union interface { + union() +} + +type UnionInteger int64 + +func (UnionInteger) union() {} + +type UnionStructA struct { + Type string `form:"type"` + A string `form:"a"` + B string `form:"b"` +} + +func (UnionStructA) union() {} + +type UnionStructB struct { + Type string `form:"type"` + A string `form:"a"` +} + +func (UnionStructB) union() {} + +type UnionTime time.Time + +func (UnionTime) union() {} + +type ReaderStruct struct { + File io.Reader `form:"file"` +} + +type NamedEnum string + +const NamedEnumFoo NamedEnum = "foo" + +type StructUnionWrapper struct { + Union StructUnion `form:"union"` +} + +type StructUnion struct { + OfInt param.Opt[int64] `form:",omitzero,inline"` + OfString param.Opt[string] `form:",omitzero,inline"` + OfEnum param.Opt[NamedEnum] `form:",omitzero,inline"` + OfA UnionStructA `form:",omitzero,inline"` + OfB UnionStructB `form:",omitzero,inline"` + param.APIUnion +} + +var tests = map[string]struct { + buf string + val any +}{ + 
"file": { + buf: `--xxx +Content-Disposition: form-data; name="file"; filename="anonymous_file" +Content-Type: application/octet-stream + +some file contents... +--xxx-- +`, + val: ReaderStruct{ + File: io.Reader(bytes.NewBuffer([]byte("some file contents..."))), + }, + }, + "map_string": { + `--xxx +Content-Disposition: form-data; name="foo" + +bar +--xxx-- +`, + map[string]string{"foo": "bar"}, + }, + + "map_interface": { + `--xxx +Content-Disposition: form-data; name="a" + +1 +--xxx +Content-Disposition: form-data; name="b" + +str +--xxx +Content-Disposition: form-data; name="c" + +false +--xxx-- +`, + map[string]any{"a": float64(1), "b": "str", "c": false}, + }, + + "primitive_struct": { + `--xxx +Content-Disposition: form-data; name="a" + +false +--xxx +Content-Disposition: form-data; name="b" + +237628372683 +--xxx +Content-Disposition: form-data; name="c" + +654 +--xxx +Content-Disposition: form-data; name="d" + +9999.43 +--xxx +Content-Disposition: form-data; name="e" + +43.76 +--xxx +Content-Disposition: form-data; name="f.0" + +1 +--xxx +Content-Disposition: form-data; name="f.1" + +2 +--xxx +Content-Disposition: form-data; name="f.2" + +3 +--xxx +Content-Disposition: form-data; name="f.3" + +4 +--xxx-- +`, + Primitives{A: false, B: 237628372683, C: uint(654), D: 9999.43, E: 43.76, F: []int{1, 2, 3, 4}}, + }, + "primitive_struct,brackets": { + `--xxx +Content-Disposition: form-data; name="f[]" + +1 +--xxx +Content-Disposition: form-data; name="f[]" + +2 +--xxx +Content-Disposition: form-data; name="f[]" + +3 +--xxx +Content-Disposition: form-data; name="f[]" + +4 +--xxx-- +`, + PrimitivesBrackets{F: []int_{1, 2, 3, 4}}, + }, + + "slices": { + `--xxx +Content-Disposition: form-data; name="slices.0.a" + +false +--xxx +Content-Disposition: form-data; name="slices.0.b" + +237628372683 +--xxx +Content-Disposition: form-data; name="slices.0.c" + +654 +--xxx +Content-Disposition: form-data; name="slices.0.d" + +9999.43 +--xxx +Content-Disposition: form-data; 
name="slices.0.e" + +43.76 +--xxx +Content-Disposition: form-data; name="slices.0.f.0" + +1 +--xxx +Content-Disposition: form-data; name="slices.0.f.1" + +2 +--xxx +Content-Disposition: form-data; name="slices.0.f.2" + +3 +--xxx +Content-Disposition: form-data; name="slices.0.f.3" + +4 +--xxx-- +`, + Slices{ + Slice: []Primitives{{A: false, B: 237628372683, C: uint(654), D: 9999.43, E: 43.76, F: []int{1, 2, 3, 4}}}, + }, + }, + "primitive_pointer_struct": { + `--xxx +Content-Disposition: form-data; name="a" + +false +--xxx +Content-Disposition: form-data; name="b" + +237628372683 +--xxx +Content-Disposition: form-data; name="c" + +654 +--xxx +Content-Disposition: form-data; name="d" + +9999.43 +--xxx +Content-Disposition: form-data; name="e" + +43.76 +--xxx +Content-Disposition: form-data; name="f.0" + +1 +--xxx +Content-Disposition: form-data; name="f.1" + +2 +--xxx +Content-Disposition: form-data; name="f.2" + +3 +--xxx +Content-Disposition: form-data; name="f.3" + +4 +--xxx +Content-Disposition: form-data; name="f.4" + +5 +--xxx-- +`, + PrimitivePointers{ + A: P(false), + B: P(237628372683), + C: P(uint(654)), + D: P(9999.43), + E: P(float32(43.76)), + F: &[]int{1, 2, 3, 4, 5}, + }, + }, + + "datetime_struct": { + `--xxx +Content-Disposition: form-data; name="date" + +2006-01-02 +--xxx +Content-Disposition: form-data; name="date-time" + +2006-01-02T15:04:05Z +--xxx-- +`, + DateTime{ + Date: time.Date(2006, time.January, 2, 0, 0, 0, 0, time.UTC), + DateTime: time.Date(2006, time.January, 2, 15, 4, 5, 0, time.UTC), + }, + }, + + "additional_properties": { + `--xxx +Content-Disposition: form-data; name="a" + +true +--xxx +Content-Disposition: form-data; name="bar" + +value +--xxx +Content-Disposition: form-data; name="foo" + +true +--xxx-- +`, + AdditionalProperties{ + A: true, + Extras: map[string]any{ + "bar": "value", + "foo": true, + }, + }, + }, + + "recursive_struct": { + `--xxx +Content-Disposition: form-data; name="child.name" + +Alex +--xxx 
+Content-Disposition: form-data; name="name" + +Robert +--xxx-- +`, + Recursive{Name: "Robert", Child: &Recursive{Name: "Alex"}}, + }, + + "unknown_struct_number": { + `--xxx +Content-Disposition: form-data; name="unknown" + +12 +--xxx-- +`, + UnknownStruct{ + Unknown: 12., + }, + }, + + "unknown_struct_map": { + `--xxx +Content-Disposition: form-data; name="unknown.foo" + +bar +--xxx-- +`, + UnknownStruct{ + Unknown: map[string]any{ + "foo": "bar", + }, + }, + }, + + "struct_union_integer": { + `--xxx +Content-Disposition: form-data; name="union" + +12 +--xxx-- +`, + StructUnionWrapper{ + Union: StructUnion{OfInt: param.NewOpt[int64](12)}, + }, + }, + + "union_integer": { + `--xxx +Content-Disposition: form-data; name="union" + +12 +--xxx-- +`, + UnionStruct{ + Union: UnionInteger(12), + }, + }, + + "struct_union_struct_discriminated_a": { + `--xxx +Content-Disposition: form-data; name="union.a" + +foo +--xxx +Content-Disposition: form-data; name="union.b" + +bar +--xxx +Content-Disposition: form-data; name="union.type" + +typeA +--xxx-- +`, + StructUnionWrapper{ + Union: StructUnion{OfA: UnionStructA{ + Type: "typeA", + A: "foo", + B: "bar", + }}, + }, + }, + + "union_struct_discriminated_a": { + `--xxx +Content-Disposition: form-data; name="union.a" + +foo +--xxx +Content-Disposition: form-data; name="union.b" + +bar +--xxx +Content-Disposition: form-data; name="union.type" + +typeA +--xxx-- +`, + + UnionStruct{ + Union: UnionStructA{ + Type: "typeA", + A: "foo", + B: "bar", + }, + }, + }, + + "struct_union_struct_discriminated_b": { + `--xxx +Content-Disposition: form-data; name="union.a" + +foo +--xxx +Content-Disposition: form-data; name="union.type" + +typeB +--xxx-- +`, + StructUnionWrapper{ + Union: StructUnion{OfB: UnionStructB{ + Type: "typeB", + A: "foo", + }}, + }, + }, + + "union_struct_discriminated_b": { + `--xxx +Content-Disposition: form-data; name="union.a" + +foo +--xxx +Content-Disposition: form-data; name="union.type" + +typeB +--xxx-- +`, + 
UnionStruct{ + Union: UnionStructB{ + Type: "typeB", + A: "foo", + }, + }, + }, + + "union_struct_time": { + `--xxx +Content-Disposition: form-data; name="union" + +2010-05-23 +--xxx-- +`, + UnionStruct{ + Union: UnionTime(time.Date(2010, 05, 23, 0, 0, 0, 0, time.UTC)), + }, + }, +} + +func TestEncode(t *testing.T) { + for name, test := range tests { + t.Run(name, func(t *testing.T) { + buf := bytes.NewBuffer(nil) + writer := multipart.NewWriter(buf) + writer.SetBoundary("xxx") + + var arrayFmt string = "indices:dots" + if tags := strings.Split(name, ","); len(tags) > 1 { + arrayFmt = tags[1] + } + + err := MarshalWithSettings(test.val, writer, arrayFmt) + if err != nil { + t.Errorf("serialization of %v failed with error %v", test.val, err) + } + err = writer.Close() + if err != nil { + t.Errorf("serialization of %v failed with error %v", test.val, err) + } + raw := buf.Bytes() + if string(raw) != strings.ReplaceAll(test.buf, "\n", "\r\n") { + t.Errorf("expected %+#v to serialize to '%s' but got '%s'", test.val, test.buf, string(raw)) + } + }) + } +} diff --git a/internal/apiform/richparam.go b/internal/apiform/richparam.go new file mode 100644 index 0000000..f58db95 --- /dev/null +++ b/internal/apiform/richparam.go @@ -0,0 +1,20 @@ +package apiform + +import ( + "github.com/ScrapeGraphAI/scrapegraph-sdk/packages/param" + "mime/multipart" + "reflect" +) + +func (e *encoder) newRichFieldTypeEncoder(t reflect.Type) encoderFunc { + f, _ := t.FieldByName("Value") + enc := e.newPrimitiveTypeEncoder(f.Type) + return func(key string, value reflect.Value, writer *multipart.Writer) error { + if opt, ok := value.Interface().(param.Optional); ok && opt.Valid() { + return enc(key, value.FieldByIndex(f.Index), writer) + } else if ok && param.IsNull(opt) { + return writer.WriteField(key, "null") + } + return nil + } +} diff --git a/internal/apiform/tag.go b/internal/apiform/tag.go new file mode 100644 index 0000000..736fc1e --- /dev/null +++ b/internal/apiform/tag.go @@ -0,0 
+1,51 @@ +package apiform + +import ( + "reflect" + "strings" +) + +const jsonStructTag = "json" +const formStructTag = "form" +const formatStructTag = "format" + +type parsedStructTag struct { + name string + required bool + extras bool + metadata bool + omitzero bool +} + +func parseFormStructTag(field reflect.StructField) (tag parsedStructTag, ok bool) { + raw, ok := field.Tag.Lookup(formStructTag) + if !ok { + raw, ok = field.Tag.Lookup(jsonStructTag) + } + if !ok { + return + } + parts := strings.Split(raw, ",") + if len(parts) == 0 { + return tag, false + } + tag.name = parts[0] + for _, part := range parts[1:] { + switch part { + case "required": + tag.required = true + case "extras": + tag.extras = true + case "metadata": + tag.metadata = true + case "omitzero": + tag.omitzero = true + } + } + return +} + +func parseFormatStructTag(field reflect.StructField) (format string, ok bool) { + format, ok = field.Tag.Lookup(formatStructTag) + return +} diff --git a/internal/apijson/decodeparam_test.go b/internal/apijson/decodeparam_test.go new file mode 100644 index 0000000..b390afb --- /dev/null +++ b/internal/apijson/decodeparam_test.go @@ -0,0 +1,410 @@ +package apijson_test + +import ( + "encoding/json" + "fmt" + "github.com/ScrapeGraphAI/scrapegraph-sdk/internal/apijson" + "github.com/ScrapeGraphAI/scrapegraph-sdk/packages/param" + "reflect" + "testing" +) + +func TestOptionalDecoders(t *testing.T) { + cases := map[string]struct { + buf string + val any + }{ + + "opt_string_present": { + `"hello"`, + param.NewOpt("hello"), + }, + "opt_string_empty_present": { + `""`, + param.NewOpt(""), + }, + "opt_string_null": { + `null`, + param.Null[string](), + }, + "opt_string_null_with_whitespace": { + ` null `, + param.Null[string](), + }, + } + + for name, test := range cases { + t.Run(name, func(t *testing.T) { + result := reflect.New(reflect.TypeOf(test.val)) + if err := json.Unmarshal([]byte(test.buf), result.Interface()); err != nil { + t.Fatalf("deserialization 
of %v failed with error %v", result, err) + } + + if !reflect.DeepEqual(result.Elem().Interface(), test.val) { + t.Fatalf("expected '%s' to deserialize to \n%#v\nbut got\n%#v", test.buf, test.val, result.Elem().Interface()) + } + }) + } +} + +type paramObject = param.APIObject + +type BasicObject struct { + ReqInt int64 `json:"req_int,required"` + ReqFloat float64 `json:"req_float,required"` + ReqString string `json:"req_string,required"` + ReqBool bool `json:"req_bool,required"` + + OptInt param.Opt[int64] `json:"opt_int"` + OptFloat param.Opt[float64] `json:"opt_float"` + OptString param.Opt[string] `json:"opt_string"` + OptBool param.Opt[bool] `json:"opt_bool"` + + paramObject +} + +func (o *BasicObject) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, o) } + +func TestBasicObjectWithNull(t *testing.T) { + raw := `{"opt_int":null,"opt_string":null,"opt_bool":null}` + var dst BasicObject + target := BasicObject{ + OptInt: param.Null[int64](), + // OptFloat: param.Opt[float64]{}, + OptString: param.Null[string](), + OptBool: param.Null[bool](), + } + + err := json.Unmarshal([]byte(raw), &dst) + if err != nil { + t.Fatalf("failed unmarshal") + } + + if !reflect.DeepEqual(dst, target) { + t.Fatalf("failed equality check %#v", dst) + } +} + +func TestBasicObject(t *testing.T) { + raw := `{"req_int":1,"req_float":1.3,"req_string":"test","req_bool":true,"opt_int":2,"opt_float":2.0,"opt_string":"test","opt_bool":false}` + var dst BasicObject + target := BasicObject{ + ReqInt: 1, + ReqFloat: 1.3, + ReqString: "test", + ReqBool: true, + OptInt: param.NewOpt[int64](2), + OptFloat: param.NewOpt(2.0), + OptString: param.NewOpt("test"), + OptBool: param.NewOpt(false), + } + + err := json.Unmarshal([]byte(raw), &dst) + if err != nil { + t.Fatalf("failed unmarshal") + } + + if !reflect.DeepEqual(dst, target) { + t.Fatalf("failed equality check %#v", dst) + } +} + +type ComplexObject struct { + Basic BasicObject `json:"basic,required"` + Enum string 
`json:"enum"` + paramObject +} + +func (o *ComplexObject) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, o) } + +func init() { + apijson.RegisterFieldValidator[ComplexObject]("enum", "a", "b", "c") +} + +func TestComplexObject(t *testing.T) { + raw := `{"basic":{"req_int":1,"req_float":1.3,"req_string":"test","req_bool":true,"opt_int":2,"opt_float":2.0,"opt_string":"test","opt_bool":false},"enum":"a"}` + var dst ComplexObject + + target := ComplexObject{ + Basic: BasicObject{ + ReqInt: 1, + ReqFloat: 1.3, + ReqString: "test", + ReqBool: true, + OptInt: param.NewOpt[int64](2), + OptFloat: param.NewOpt(2.0), + OptString: param.NewOpt("test"), + OptBool: param.NewOpt(false), + }, + Enum: "a", + } + + err := json.Unmarshal([]byte(raw), &dst) + if err != nil { + t.Fatalf("failed unmarshal") + } + + if !reflect.DeepEqual(dst, target) { + t.Fatalf("failed equality check %#v", dst) + } +} + +type paramUnion = param.APIUnion + +type MemberA struct { + Name string `json:"name,required"` + Age int `json:"age,required"` +} + +type MemberB struct { + Name string `json:"name,required"` + Age string `json:"age,required"` +} + +type MemberC struct { + Name string `json:"name,required"` + Age int `json:"age,required"` + Status string `json:"status"` +} + +type MemberD struct { + Cost int `json:"cost,required"` + Status string `json:"status,required"` +} + +type MemberE struct { + Cost int `json:"cost,required"` + Status string `json:"status,required"` +} + +type MemberF struct { + D int `json:"d"` + E string `json:"e"` + F float64 `json:"f"` + G param.Opt[int] `json:"g"` +} + +type MemberG struct { + D int `json:"d"` + E string `json:"e"` + F float64 `json:"f"` + G param.Opt[bool] `json:"g"` +} + +func init() { + apijson.RegisterFieldValidator[MemberD]("status", "good", "ok", "bad") + apijson.RegisterFieldValidator[MemberE]("status", "GOOD", "OK", "BAD") +} + +type UnionStruct struct { + OfMemberA *MemberA `json:",inline"` + OfMemberB *MemberB 
`json:",inline"` + OfMemberC *MemberC `json:",inline"` + OfMemberD *MemberD `json:",inline"` + OfMemberE *MemberE `json:",inline"` + OfMemberF *MemberF `json:",inline"` + OfMemberG *MemberG `json:",inline"` + OfString param.Opt[string] `json:",inline"` + + paramUnion +} + +func (union *UnionStruct) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, union) +} + +func TestUnionStruct(t *testing.T) { + tests := map[string]struct { + raw string + target UnionStruct + shouldFail bool + }{ + "fail": { + raw: `1200`, + target: UnionStruct{}, + shouldFail: true, + }, + "easy": { + raw: `{"age":30}`, + target: UnionStruct{OfMemberA: &MemberA{Age: 30}}, + }, + "less-easy": { + raw: `{"age":"thirty"}`, + target: UnionStruct{OfMemberB: &MemberB{Age: "thirty"}}, + }, + "even-less-easy": { + raw: `{"age":"30"}`, + target: UnionStruct{OfMemberB: &MemberB{Age: "30"}}, + }, + "medium": { + raw: `{"name":"jacob","age":30}`, + target: UnionStruct{OfMemberA: &MemberA{ + Age: 30, + Name: "jacob", + }}, + }, + "less-medium": { + raw: `{"name":"jacob","age":"thirty"}`, + target: UnionStruct{OfMemberB: &MemberB{ + Age: "thirty", + Name: "jacob", + }}, + }, + "even-less-medium": { + raw: `{"name":"jacob","age":"30"}`, + target: UnionStruct{OfMemberB: &MemberB{ + Name: "jacob", + Age: "30", + }}, + }, + "hard": { + raw: `{"name":"jacob","age":30,"status":"active"}`, + target: UnionStruct{OfMemberC: &MemberC{ + Name: "jacob", + Age: 30, + Status: "active", + }}, + }, + "inline-string": { + raw: `"hello there"`, + target: UnionStruct{OfString: param.NewOpt("hello there")}, + }, + "enum-field": { + raw: `{"cost":100,"status":"ok"}`, + target: UnionStruct{OfMemberD: &MemberD{Cost: 100, Status: "ok"}}, + }, + "other-enum-field": { + raw: `{"cost":100,"status":"GOOD"}`, + target: UnionStruct{OfMemberE: &MemberE{Cost: 100, Status: "GOOD"}}, + }, + "tricky-extra-fields": { + raw: `{"d":12,"e":"hello","f":1.00}`, + target: UnionStruct{OfMemberF: &MemberF{D: 12, E: "hello", F: 
1.00}}, + }, + "optional-fields": { + raw: `{"d":12,"e":"hello","f":1.00,"g":12}`, + target: UnionStruct{OfMemberF: &MemberF{D: 12, E: "hello", F: 1.00, G: param.NewOpt(12)}}, + }, + "optional-fields-2": { + raw: `{"d":12,"e":"hello","f":1.00,"g":false}`, + target: UnionStruct{OfMemberG: &MemberG{D: 12, E: "hello", F: 1.00, G: param.NewOpt(false)}}, + }, + } + + for name, test := range tests { + var dst UnionStruct + t.Run(name, func(t *testing.T) { + err := json.Unmarshal([]byte(test.raw), &dst) + if err != nil && !test.shouldFail { + t.Fatalf("failed unmarshal with err: %v %#v", err, dst) + } + + if !reflect.DeepEqual(dst, test.target) { + if dst.OfMemberA != nil { + fmt.Printf("%#v", dst.OfMemberA) + } + t.Fatalf("failed equality, got %#v but expected %#v", dst, test.target) + } + }) + } +} + +type ConstantA string +type ConstantB string +type ConstantC string + +func (c ConstantA) Default() string { return "A" } +func (c ConstantB) Default() string { return "B" } +func (c ConstantC) Default() string { return "C" } + +type DiscVariantA struct { + Name string `json:"name,required"` + Age int `json:"age,required"` + Type ConstantA `json:"type,required"` +} + +type DiscVariantB struct { + Name string `json:"name,required"` + Age int `json:"age,required"` + Type ConstantB `json:"type,required"` +} + +type DiscVariantC struct { + Name string `json:"name,required"` + Age float64 `json:"age,required"` + Type ConstantC `json:"type,required"` +} + +type DiscriminatedUnion struct { + OfA *DiscVariantA `json:",inline"` + OfB *DiscVariantB `json:",inline"` + OfC *DiscVariantC `json:",inline"` + + paramUnion +} + +func init() { + apijson.RegisterDiscriminatedUnion[DiscriminatedUnion]("type", map[string]reflect.Type{ + "A": reflect.TypeOf(DiscVariantA{}), + "B": reflect.TypeOf(DiscVariantB{}), + "C": reflect.TypeOf(DiscVariantC{}), + }) +} + +func (d *DiscriminatedUnion) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, d) +} + +func 
TestDiscriminatedUnion(t *testing.T) { + tests := map[string]struct { + raw string + target DiscriminatedUnion + shouldFail bool + }{ + "variant_a": { + raw: `{"name":"Alice","age":25,"type":"A"}`, + target: DiscriminatedUnion{OfA: &DiscVariantA{ + Name: "Alice", + Age: 25, + Type: "A", + }}, + }, + "variant_b": { + raw: `{"name":"Bob","age":30,"type":"B"}`, + target: DiscriminatedUnion{OfB: &DiscVariantB{ + Name: "Bob", + Age: 30, + Type: "B", + }}, + }, + "variant_c": { + raw: `{"name":"Charlie","age":35.5,"type":"C"}`, + target: DiscriminatedUnion{OfC: &DiscVariantC{ + Name: "Charlie", + Age: 35.5, + Type: "C", + }}, + }, + "invalid_type": { + raw: `{"name":"Unknown","age":40,"type":"D"}`, + target: DiscriminatedUnion{}, + shouldFail: true, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + var dst DiscriminatedUnion + err := json.Unmarshal([]byte(test.raw), &dst) + if err != nil && !test.shouldFail { + t.Fatalf("failed unmarshal with err: %v", err) + } + if err == nil && test.shouldFail { + t.Fatalf("expected unmarshal to fail but it succeeded") + } + if !reflect.DeepEqual(dst, test.target) { + t.Fatalf("failed equality, got %#v but expected %#v", dst, test.target) + } + }) + } +} diff --git a/internal/apijson/decoder.go b/internal/apijson/decoder.go new file mode 100644 index 0000000..43a083c --- /dev/null +++ b/internal/apijson/decoder.go @@ -0,0 +1,691 @@ +// The deserialization algorithm from apijson may be subject to improvements +// between minor versions, particularly with respect to calling [json.Unmarshal] +// into param unions. 
+ +package apijson + +import ( + "encoding/json" + "fmt" + "github.com/ScrapeGraphAI/scrapegraph-sdk/packages/param" + "reflect" + "strconv" + "sync" + "time" + "unsafe" + + "github.com/tidwall/gjson" +) + +// decoders is a synchronized map with roughly the following type: +// map[reflect.Type]decoderFunc +var decoders sync.Map + +// Unmarshal is similar to [encoding/json.Unmarshal] and parses the JSON-encoded +// data and stores it in the given pointer. +func Unmarshal(raw []byte, to any) error { + d := &decoderBuilder{dateFormat: time.RFC3339} + return d.unmarshal(raw, to) +} + +// UnmarshalRoot is like Unmarshal, but doesn't try to call MarshalJSON on the +// root element. Useful if a struct's UnmarshalJSON is overrode to use the +// behavior of this encoder versus the standard library. +func UnmarshalRoot(raw []byte, to any) error { + d := &decoderBuilder{dateFormat: time.RFC3339, root: true} + return d.unmarshal(raw, to) +} + +// decoderBuilder contains the 'compile-time' state of the decoder. +type decoderBuilder struct { + // Whether or not this is the first element and called by [UnmarshalRoot], see + // the documentation there to see why this is necessary. + root bool + // The dateFormat (a format string for [time.Format]) which is chosen by the + // last struct tag that was seen. + dateFormat string +} + +// decoderState contains the 'run-time' state of the decoder. +type decoderState struct { + strict bool + exactness exactness + validator *validationEntry +} + +// Exactness refers to how close to the type the result was if deserialization +// was successful. This is useful in deserializing unions, where you want to try +// each entry, first with strict, then with looser validation, without actually +// having to do a lot of redundant work by marshalling twice (or maybe even more +// times). +type exactness int8 + +const ( + // Some values had to fudged a bit, for example by converting a string to an + // int, or an enum with extra values. 
+ loose exactness = iota + // There are some extra arguments, but other wise it matches the union. + extras + // Exactly right. + exact +) + +type decoderFunc func(node gjson.Result, value reflect.Value, state *decoderState) error + +type decoderField struct { + tag parsedStructTag + fn decoderFunc + idx []int + goname string +} + +type decoderEntry struct { + reflect.Type + dateFormat string + root bool +} + +func (d *decoderBuilder) unmarshal(raw []byte, to any) error { + value := reflect.ValueOf(to).Elem() + result := gjson.ParseBytes(raw) + if !value.IsValid() { + return fmt.Errorf("apijson: cannot marshal into invalid value") + } + return d.typeDecoder(value.Type())(result, value, &decoderState{strict: false, exactness: exact}) +} + +// unmarshalWithExactness is used for internal testing purposes. +func (d *decoderBuilder) unmarshalWithExactness(raw []byte, to any) (exactness, error) { + value := reflect.ValueOf(to).Elem() + result := gjson.ParseBytes(raw) + if !value.IsValid() { + return 0, fmt.Errorf("apijson: cannot marshal into invalid value") + } + state := decoderState{strict: false, exactness: exact} + err := d.typeDecoder(value.Type())(result, value, &state) + return state.exactness, err +} + +func (d *decoderBuilder) typeDecoder(t reflect.Type) decoderFunc { + entry := decoderEntry{ + Type: t, + dateFormat: d.dateFormat, + root: d.root, + } + + if fi, ok := decoders.Load(entry); ok { + return fi.(decoderFunc) + } + + // To deal with recursive types, populate the map with an + // indirect func before we build it. This type waits on the + // real func (f) to be ready and then calls it. This indirect + // func is only used for recursive types. 
+ var ( + wg sync.WaitGroup + f decoderFunc + ) + wg.Add(1) + fi, loaded := decoders.LoadOrStore(entry, decoderFunc(func(node gjson.Result, v reflect.Value, state *decoderState) error { + wg.Wait() + return f(node, v, state) + })) + if loaded { + return fi.(decoderFunc) + } + + // Compute the real decoder and replace the indirect func with it. + f = d.newTypeDecoder(t) + wg.Done() + decoders.Store(entry, f) + return f +} + +// validatedTypeDecoder wraps the type decoder with a validator. This is helpful +// for ensuring that enum fields are correct. +func (d *decoderBuilder) validatedTypeDecoder(t reflect.Type, entry *validationEntry) decoderFunc { + dec := d.typeDecoder(t) + if entry == nil { + return dec + } + + // Thread the current validation entry through the decoder, + // but clean up in time for the next field. + return func(node gjson.Result, v reflect.Value, state *decoderState) error { + state.validator = entry + err := dec(node, v, state) + state.validator = nil + return err + } +} + +func indirectUnmarshalerDecoder(n gjson.Result, v reflect.Value, state *decoderState) error { + return v.Addr().Interface().(json.Unmarshaler).UnmarshalJSON([]byte(n.Raw)) +} + +func unmarshalerDecoder(n gjson.Result, v reflect.Value, state *decoderState) error { + if v.Kind() == reflect.Pointer && v.CanSet() { + v.Set(reflect.New(v.Type().Elem())) + } + return v.Interface().(json.Unmarshaler).UnmarshalJSON([]byte(n.Raw)) +} + +func (d *decoderBuilder) newTypeDecoder(t reflect.Type) decoderFunc { + if t.ConvertibleTo(reflect.TypeOf(time.Time{})) { + return d.newTimeTypeDecoder(t) + } + + if t.Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) { + return d.newOptTypeDecoder(t) + } + + if !d.root && t.Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()) { + return unmarshalerDecoder + } + if !d.root && reflect.PointerTo(t).Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()) { + if _, ok := unionVariants[t]; !ok { + return indirectUnmarshalerDecoder + } + 
} + d.root = false + + if _, ok := unionRegistry[t]; ok { + if isStructUnion(t) { + return d.newStructUnionDecoder(t) + } + return d.newUnionDecoder(t) + } + + switch t.Kind() { + case reflect.Pointer: + inner := t.Elem() + innerDecoder := d.typeDecoder(inner) + + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + if !v.IsValid() { + return fmt.Errorf("apijson: unexpected invalid reflection value %+#v", v) + } + + newValue := reflect.New(inner).Elem() + err := innerDecoder(n, newValue, state) + if err != nil { + return err + } + + v.Set(newValue.Addr()) + return nil + } + case reflect.Struct: + if isStructUnion(t) { + return d.newStructUnionDecoder(t) + } + return d.newStructTypeDecoder(t) + case reflect.Array: + fallthrough + case reflect.Slice: + return d.newArrayTypeDecoder(t) + case reflect.Map: + return d.newMapDecoder(t) + case reflect.Interface: + return func(node gjson.Result, value reflect.Value, state *decoderState) error { + if !value.IsValid() { + return fmt.Errorf("apijson: unexpected invalid value %+#v", value) + } + if node.Value() != nil && value.CanSet() { + value.Set(reflect.ValueOf(node.Value())) + } + return nil + } + default: + return d.newPrimitiveTypeDecoder(t) + } +} + +func (d *decoderBuilder) newMapDecoder(t reflect.Type) decoderFunc { + keyType := t.Key() + itemType := t.Elem() + itemDecoder := d.typeDecoder(itemType) + + return func(node gjson.Result, value reflect.Value, state *decoderState) (err error) { + mapValue := reflect.MakeMapWithSize(t, len(node.Map())) + + node.ForEach(func(key, value gjson.Result) bool { + // It's fine for us to just use `ValueOf` here because the key types will + // always be primitive types so we don't need to decode it using the standard pattern + keyValue := reflect.ValueOf(key.Value()) + if !keyValue.IsValid() { + if err == nil { + err = fmt.Errorf("apijson: received invalid key type %v", keyValue.String()) + } + return false + } + if keyValue.Type() != keyType { + if err == nil 
{ + err = fmt.Errorf("apijson: expected key type %v but got %v", keyType, keyValue.Type()) + } + return false + } + + itemValue := reflect.New(itemType).Elem() + itemerr := itemDecoder(value, itemValue, state) + if itemerr != nil { + if err == nil { + err = itemerr + } + return false + } + + mapValue.SetMapIndex(keyValue, itemValue) + return true + }) + + if err != nil { + return err + } + value.Set(mapValue) + return nil + } +} + +func (d *decoderBuilder) newArrayTypeDecoder(t reflect.Type) decoderFunc { + itemDecoder := d.typeDecoder(t.Elem()) + + return func(node gjson.Result, value reflect.Value, state *decoderState) (err error) { + if !node.IsArray() { + return fmt.Errorf("apijson: could not deserialize to an array") + } + + arrayNode := node.Array() + + arrayValue := reflect.MakeSlice(reflect.SliceOf(t.Elem()), len(arrayNode), len(arrayNode)) + for i, itemNode := range arrayNode { + err = itemDecoder(itemNode, arrayValue.Index(i), state) + if err != nil { + return err + } + } + + value.Set(arrayValue) + return nil + } +} + +func (d *decoderBuilder) newStructTypeDecoder(t reflect.Type) decoderFunc { + // map of json field name to struct field decoders + decoderFields := map[string]decoderField{} + anonymousDecoders := []decoderField{} + extraDecoder := (*decoderField)(nil) + var inlineDecoders []decoderField + + validationEntries := validationRegistry[t] + + for i := 0; i < t.NumField(); i++ { + idx := []int{i} + field := t.FieldByIndex(idx) + if !field.IsExported() { + continue + } + + var validator *validationEntry + for _, entry := range validationEntries { + if entry.field.Offset == field.Offset { + validator = &entry + break + } + } + + // If this is an embedded struct, traverse one level deeper to extract + // the fields and get their encoders as well. 
+ if field.Anonymous { + anonymousDecoders = append(anonymousDecoders, decoderField{ + fn: d.typeDecoder(field.Type), + idx: idx[:], + }) + continue + } + // If json tag is not present, then we skip, which is intentionally + // different behavior from the stdlib. + ptag, ok := parseJSONStructTag(field) + if !ok { + continue + } + // We only want to support unexported fields if they're tagged with + // `extras` because that field shouldn't be part of the public API. + if ptag.extras { + extraDecoder = &decoderField{ptag, d.typeDecoder(field.Type.Elem()), idx, field.Name} + continue + } + if ptag.inline { + df := decoderField{ptag, d.typeDecoder(field.Type), idx, field.Name} + inlineDecoders = append(inlineDecoders, df) + continue + } + if ptag.metadata { + continue + } + + oldFormat := d.dateFormat + dateFormat, ok := parseFormatStructTag(field) + if ok { + switch dateFormat { + case "date-time": + d.dateFormat = time.RFC3339 + case "date": + d.dateFormat = "2006-01-02" + } + } + + decoderFields[ptag.name] = decoderField{ + ptag, + d.validatedTypeDecoder(field.Type, validator), + idx, field.Name, + } + + d.dateFormat = oldFormat + } + + return func(node gjson.Result, value reflect.Value, state *decoderState) (err error) { + if field := value.FieldByName("JSON"); field.IsValid() { + if raw := field.FieldByName("raw"); raw.IsValid() { + setUnexportedField(raw, node.Raw) + } + } + + for _, decoder := range anonymousDecoders { + // ignore errors + decoder.fn(node, value.FieldByIndex(decoder.idx), state) + } + + for _, inlineDecoder := range inlineDecoders { + var meta Field + dest := value.FieldByIndex(inlineDecoder.idx) + isValid := false + if dest.IsValid() && node.Type != gjson.Null { + inlineState := decoderState{exactness: state.exactness, strict: true} + err = inlineDecoder.fn(node, dest, &inlineState) + if err == nil { + isValid = true + } + } + + if node.Type == gjson.Null { + meta = Field{ + raw: node.Raw, + status: null, + } + } else if !isValid { + // If an 
inline decoder fails, unset the field and move on. + if dest.IsValid() { + dest.SetZero() + } + continue + } else if isValid { + meta = Field{ + raw: node.Raw, + status: valid, + } + } + setMetadataSubField(value, inlineDecoder.idx, inlineDecoder.goname, meta) + } + + typedExtraType := reflect.Type(nil) + typedExtraFields := reflect.Value{} + if extraDecoder != nil { + typedExtraType = value.FieldByIndex(extraDecoder.idx).Type() + typedExtraFields = reflect.MakeMap(typedExtraType) + } + untypedExtraFields := map[string]Field{} + + for fieldName, itemNode := range node.Map() { + df, explicit := decoderFields[fieldName] + var ( + dest reflect.Value + fn decoderFunc + meta Field + ) + if explicit { + fn = df.fn + dest = value.FieldByIndex(df.idx) + } + if !explicit && extraDecoder != nil { + dest = reflect.New(typedExtraType.Elem()).Elem() + fn = extraDecoder.fn + } + + isValid := false + if dest.IsValid() && itemNode.Type != gjson.Null { + err = fn(itemNode, dest, state) + if err == nil { + isValid = true + } + } + + // Handle null [param.Opt] + if itemNode.Type == gjson.Null && dest.IsValid() && dest.Type().Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) { + dest.Addr().Interface().(json.Unmarshaler).UnmarshalJSON([]byte(itemNode.Raw)) + continue + } + + if itemNode.Type == gjson.Null { + meta = Field{ + raw: itemNode.Raw, + status: null, + } + } else if !isValid { + meta = Field{ + raw: itemNode.Raw, + status: invalid, + } + } else if isValid { + meta = Field{ + raw: itemNode.Raw, + status: valid, + } + } + + if explicit { + setMetadataSubField(value, df.idx, df.goname, meta) + } + if !explicit { + untypedExtraFields[fieldName] = meta + } + if !explicit && extraDecoder != nil { + typedExtraFields.SetMapIndex(reflect.ValueOf(fieldName), dest) + } + } + + if extraDecoder != nil && typedExtraFields.Len() > 0 { + value.FieldByIndex(extraDecoder.idx).Set(typedExtraFields) + } + + // Set exactness to 'extras' if there are untyped, extra fields. 
+ if len(untypedExtraFields) > 0 && state.exactness > extras { + state.exactness = extras + } + + if len(untypedExtraFields) > 0 { + setMetadataExtraFields(value, []int{-1}, "ExtraFields", untypedExtraFields) + } + return nil + } +} + +func (d *decoderBuilder) newPrimitiveTypeDecoder(t reflect.Type) decoderFunc { + switch t.Kind() { + case reflect.String: + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + v.SetString(n.String()) + if guardStrict(state, n.Type != gjson.String) { + return fmt.Errorf("apijson: failed to parse string strictly") + } + // Everything that is not an object can be loosely stringified. + if n.Type == gjson.JSON { + return fmt.Errorf("apijson: failed to parse string") + } + + state.validateString(v) + + if guardUnknown(state, v) { + return fmt.Errorf("apijson: failed string enum validation") + } + return nil + } + case reflect.Bool: + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + v.SetBool(n.Bool()) + if guardStrict(state, n.Type != gjson.True && n.Type != gjson.False) { + return fmt.Errorf("apijson: failed to parse bool strictly") + } + // Numbers and strings that are either 'true' or 'false' can be loosely + // deserialized as bool. + if n.Type == gjson.String && (n.Raw != "true" && n.Raw != "false") || n.Type == gjson.JSON { + return fmt.Errorf("apijson: failed to parse bool") + } + + state.validateBool(v) + + if guardUnknown(state, v) { + return fmt.Errorf("apijson: failed bool enum validation") + } + return nil + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + v.SetInt(n.Int()) + if guardStrict(state, n.Type != gjson.Number || n.Num != float64(int(n.Num))) { + return fmt.Errorf("apijson: failed to parse int strictly") + } + // Numbers, booleans, and strings that maybe look like numbers can be + // loosely deserialized as numbers. 
+ if n.Type == gjson.JSON || (n.Type == gjson.String && !canParseAsNumber(n.Str)) { + return fmt.Errorf("apijson: failed to parse int") + } + + state.validateInt(v) + + if guardUnknown(state, v) { + return fmt.Errorf("apijson: failed int enum validation") + } + return nil + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + v.SetUint(n.Uint()) + if guardStrict(state, n.Type != gjson.Number || n.Num != float64(int(n.Num)) || n.Num < 0) { + return fmt.Errorf("apijson: failed to parse uint strictly") + } + // Numbers, booleans, and strings that maybe look like numbers can be + // loosely deserialized as uint. + if n.Type == gjson.JSON || (n.Type == gjson.String && !canParseAsNumber(n.Str)) { + return fmt.Errorf("apijson: failed to parse uint") + } + if guardUnknown(state, v) { + return fmt.Errorf("apijson: failed uint enum validation") + } + return nil + } + case reflect.Float32, reflect.Float64: + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + v.SetFloat(n.Float()) + if guardStrict(state, n.Type != gjson.Number) { + return fmt.Errorf("apijson: failed to parse float strictly") + } + // Numbers, booleans, and strings that maybe look like numbers can be + // loosely deserialized as floats. 
+ if n.Type == gjson.JSON || (n.Type == gjson.String && !canParseAsNumber(n.Str)) { + return fmt.Errorf("apijson: failed to parse float") + } + if guardUnknown(state, v) { + return fmt.Errorf("apijson: failed float enum validation") + } + return nil + } + default: + return func(node gjson.Result, v reflect.Value, state *decoderState) error { + return fmt.Errorf("unknown type received at primitive decoder: %s", t.String()) + } + } +} + +func (d *decoderBuilder) newOptTypeDecoder(t reflect.Type) decoderFunc { + for t.Kind() == reflect.Pointer { + t = t.Elem() + } + valueField, _ := t.FieldByName("Value") + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + state.validateOptKind(n, valueField.Type) + return v.Addr().Interface().(json.Unmarshaler).UnmarshalJSON([]byte(n.Raw)) + } +} + +func (d *decoderBuilder) newTimeTypeDecoder(t reflect.Type) decoderFunc { + format := d.dateFormat + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + parsed, err := time.Parse(format, n.Str) + if err == nil { + v.Set(reflect.ValueOf(parsed).Convert(t)) + return nil + } + + if guardStrict(state, true) { + return err + } + + layouts := []string{ + "2006-01-02", + "2006-01-02T15:04:05Z07:00", + "2006-01-02T15:04:05Z0700", + "2006-01-02T15:04:05", + "2006-01-02 15:04:05Z07:00", + "2006-01-02 15:04:05Z0700", + "2006-01-02 15:04:05", + } + + for _, layout := range layouts { + parsed, err := time.Parse(layout, n.Str) + if err == nil { + v.Set(reflect.ValueOf(parsed).Convert(t)) + return nil + } + } + + return fmt.Errorf("unable to leniently parse date-time string: %s", n.Str) + } +} + +func setUnexportedField(field reflect.Value, value any) { + reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Set(reflect.ValueOf(value)) +} + +func guardStrict(state *decoderState, cond bool) bool { + if !cond { + return false + } + + if state.strict { + return true + } + + state.exactness = loose + return false +} + +func 
canParseAsNumber(str string) bool { + _, err := strconv.ParseFloat(str, 64) + return err == nil +} + +var stringType = reflect.TypeOf(string("")) + +func guardUnknown(state *decoderState, v reflect.Value) bool { + if have, ok := v.Interface().(interface{ IsKnown() bool }); guardStrict(state, ok && !have.IsKnown()) { + return true + } + + constantString, ok := v.Interface().(interface{ Default() string }) + named := v.Type() != stringType + if guardStrict(state, ok && named && v.Equal(reflect.ValueOf(constantString.Default()))) { + return true + } + return false +} diff --git a/internal/apijson/decoderesp_test.go b/internal/apijson/decoderesp_test.go new file mode 100644 index 0000000..f55d0e4 --- /dev/null +++ b/internal/apijson/decoderesp_test.go @@ -0,0 +1,30 @@ +package apijson_test + +import ( + "encoding/json" + "github.com/ScrapeGraphAI/scrapegraph-sdk/internal/apijson" + "github.com/ScrapeGraphAI/scrapegraph-sdk/packages/respjson" + "testing" +) + +type StructWithNullExtraField struct { + Results []string `json:"results,required"` + JSON struct { + Results respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +func (r *StructWithNullExtraField) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +func TestDecodeWithNullExtraField(t *testing.T) { + raw := `{"something_else":null}` + var dst *StructWithNullExtraField + err := json.Unmarshal([]byte(raw), &dst) + if err != nil { + t.Fatalf("error: %s", err.Error()) + } +} diff --git a/internal/apijson/encoder.go b/internal/apijson/encoder.go new file mode 100644 index 0000000..8358a2f --- /dev/null +++ b/internal/apijson/encoder.go @@ -0,0 +1,392 @@ +package apijson + +import ( + "bytes" + "encoding/json" + "fmt" + "reflect" + "sort" + "strconv" + "strings" + "sync" + "time" + + "github.com/tidwall/sjson" +) + +var encoders sync.Map // map[encoderEntry]encoderFunc + +func Marshal(value any) ([]byte, error) { + e := &encoder{dateFormat: time.RFC3339} + 
return e.marshal(value) +} + +func MarshalRoot(value any) ([]byte, error) { + e := &encoder{root: true, dateFormat: time.RFC3339} + return e.marshal(value) +} + +type encoder struct { + dateFormat string + root bool +} + +type encoderFunc func(value reflect.Value) ([]byte, error) + +type encoderField struct { + tag parsedStructTag + fn encoderFunc + idx []int +} + +type encoderEntry struct { + reflect.Type + dateFormat string + root bool +} + +func (e *encoder) marshal(value any) ([]byte, error) { + val := reflect.ValueOf(value) + if !val.IsValid() { + return nil, nil + } + typ := val.Type() + enc := e.typeEncoder(typ) + return enc(val) +} + +func (e *encoder) typeEncoder(t reflect.Type) encoderFunc { + entry := encoderEntry{ + Type: t, + dateFormat: e.dateFormat, + root: e.root, + } + + if fi, ok := encoders.Load(entry); ok { + return fi.(encoderFunc) + } + + // To deal with recursive types, populate the map with an + // indirect func before we build it. This type waits on the + // real func (f) to be ready and then calls it. This indirect + // func is only used for recursive types. + var ( + wg sync.WaitGroup + f encoderFunc + ) + wg.Add(1) + fi, loaded := encoders.LoadOrStore(entry, encoderFunc(func(v reflect.Value) ([]byte, error) { + wg.Wait() + return f(v) + })) + if loaded { + return fi.(encoderFunc) + } + + // Compute the real encoder and replace the indirect func with it. 
+ f = e.newTypeEncoder(t) + wg.Done() + encoders.Store(entry, f) + return f +} + +func marshalerEncoder(v reflect.Value) ([]byte, error) { + return v.Interface().(json.Marshaler).MarshalJSON() +} + +func indirectMarshalerEncoder(v reflect.Value) ([]byte, error) { + return v.Addr().Interface().(json.Marshaler).MarshalJSON() +} + +func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc { + if t.ConvertibleTo(reflect.TypeOf(time.Time{})) { + return e.newTimeTypeEncoder() + } + if !e.root && t.Implements(reflect.TypeOf((*json.Marshaler)(nil)).Elem()) { + return marshalerEncoder + } + if !e.root && reflect.PointerTo(t).Implements(reflect.TypeOf((*json.Marshaler)(nil)).Elem()) { + return indirectMarshalerEncoder + } + e.root = false + switch t.Kind() { + case reflect.Pointer: + inner := t.Elem() + + innerEncoder := e.typeEncoder(inner) + return func(v reflect.Value) ([]byte, error) { + if !v.IsValid() || v.IsNil() { + return nil, nil + } + return innerEncoder(v.Elem()) + } + case reflect.Struct: + return e.newStructTypeEncoder(t) + case reflect.Array: + fallthrough + case reflect.Slice: + return e.newArrayTypeEncoder(t) + case reflect.Map: + return e.newMapEncoder(t) + case reflect.Interface: + return e.newInterfaceEncoder() + default: + return e.newPrimitiveTypeEncoder(t) + } +} + +func (e *encoder) newPrimitiveTypeEncoder(t reflect.Type) encoderFunc { + switch t.Kind() { + // Note that we could use `gjson` to encode these types but it would complicate our + // code more and this current code shouldn't cause any issues + case reflect.String: + return func(v reflect.Value) ([]byte, error) { + return json.Marshal(v.Interface()) + } + case reflect.Bool: + return func(v reflect.Value) ([]byte, error) { + if v.Bool() { + return []byte("true"), nil + } + return []byte("false"), nil + } + case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: + return func(v reflect.Value) ([]byte, error) { + return []byte(strconv.FormatInt(v.Int(), 10)), nil + } + case 
reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return func(v reflect.Value) ([]byte, error) { + return []byte(strconv.FormatUint(v.Uint(), 10)), nil + } + case reflect.Float32: + return func(v reflect.Value) ([]byte, error) { + return []byte(strconv.FormatFloat(v.Float(), 'f', -1, 32)), nil + } + case reflect.Float64: + return func(v reflect.Value) ([]byte, error) { + return []byte(strconv.FormatFloat(v.Float(), 'f', -1, 64)), nil + } + default: + return func(v reflect.Value) ([]byte, error) { + return nil, fmt.Errorf("unknown type received at primitive encoder: %s", t.String()) + } + } +} + +func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc { + itemEncoder := e.typeEncoder(t.Elem()) + + return func(value reflect.Value) ([]byte, error) { + json := []byte("[]") + for i := 0; i < value.Len(); i++ { + var value, err = itemEncoder(value.Index(i)) + if err != nil { + return nil, err + } + if value == nil { + // Assume that empty items should be inserted as `null` so that the output array + // will be the same length as the input array + value = []byte("null") + } + + json, err = sjson.SetRawBytes(json, "-1", value) + if err != nil { + return nil, err + } + } + + return json, nil + } +} + +func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc { + encoderFields := []encoderField{} + extraEncoder := (*encoderField)(nil) + + // This helper allows us to recursively collect field encoders into a flat + // array. The parameter `index` keeps track of the access patterns necessary + // to get to some field. + var collectEncoderFields func(r reflect.Type, index []int) + collectEncoderFields = func(r reflect.Type, index []int) { + for i := 0; i < r.NumField(); i++ { + idx := append(index, i) + field := t.FieldByIndex(idx) + if !field.IsExported() { + continue + } + // If this is an embedded struct, traverse one level deeper to extract + // the field and get their encoders as well. 
+ if field.Anonymous { + collectEncoderFields(field.Type, idx) + continue + } + // If json tag is not present, then we skip, which is intentionally + // different behavior from the stdlib. + ptag, ok := parseJSONStructTag(field) + if !ok { + continue + } + // We only want to support unexported field if they're tagged with + // `extras` because that field shouldn't be part of the public API. We + // also want to only keep the top level extras + if ptag.extras && len(index) == 0 { + extraEncoder = &encoderField{ptag, e.typeEncoder(field.Type.Elem()), idx} + continue + } + if ptag.name == "-" { + continue + } + + dateFormat, ok := parseFormatStructTag(field) + oldFormat := e.dateFormat + if ok { + switch dateFormat { + case "date-time": + e.dateFormat = time.RFC3339 + case "date": + e.dateFormat = "2006-01-02" + } + } + encoderFields = append(encoderFields, encoderField{ptag, e.typeEncoder(field.Type), idx}) + e.dateFormat = oldFormat + } + } + collectEncoderFields(t, []int{}) + + // Ensure deterministic output by sorting by lexicographic order + sort.Slice(encoderFields, func(i, j int) bool { + return encoderFields[i].tag.name < encoderFields[j].tag.name + }) + + return func(value reflect.Value) (json []byte, err error) { + json = []byte("{}") + + for _, ef := range encoderFields { + field := value.FieldByIndex(ef.idx) + encoded, err := ef.fn(field) + if err != nil { + return nil, err + } + if encoded == nil { + continue + } + json, err = sjson.SetRawBytes(json, ef.tag.name, encoded) + if err != nil { + return nil, err + } + } + + if extraEncoder != nil { + json, err = e.encodeMapEntries(json, value.FieldByIndex(extraEncoder.idx)) + if err != nil { + return nil, err + } + } + return + } +} + +func (e *encoder) newFieldTypeEncoder(t reflect.Type) encoderFunc { + f, _ := t.FieldByName("Value") + enc := e.typeEncoder(f.Type) + + return func(value reflect.Value) (json []byte, err error) { + present := value.FieldByName("Present") + if !present.Bool() { + return nil, nil 
+ } + null := value.FieldByName("Null") + if null.Bool() { + return []byte("null"), nil + } + raw := value.FieldByName("Raw") + if !raw.IsNil() { + return e.typeEncoder(raw.Type())(raw) + } + return enc(value.FieldByName("Value")) + } +} + +func (e *encoder) newTimeTypeEncoder() encoderFunc { + format := e.dateFormat + return func(value reflect.Value) (json []byte, err error) { + return []byte(`"` + value.Convert(reflect.TypeOf(time.Time{})).Interface().(time.Time).Format(format) + `"`), nil + } +} + +func (e encoder) newInterfaceEncoder() encoderFunc { + return func(value reflect.Value) ([]byte, error) { + value = value.Elem() + if !value.IsValid() { + return nil, nil + } + return e.typeEncoder(value.Type())(value) + } +} + +// Given a []byte of json (may either be an empty object or an object that already contains entries) +// encode all of the entries in the map to the json byte array. +func (e *encoder) encodeMapEntries(json []byte, v reflect.Value) ([]byte, error) { + type mapPair struct { + key []byte + value reflect.Value + } + + pairs := []mapPair{} + keyEncoder := e.typeEncoder(v.Type().Key()) + + iter := v.MapRange() + for iter.Next() { + var encodedKeyString string + if iter.Key().Type().Kind() == reflect.String { + encodedKeyString = iter.Key().String() + } else { + var err error + encodedKeyBytes, err := keyEncoder(iter.Key()) + if err != nil { + return nil, err + } + encodedKeyString = string(encodedKeyBytes) + } + encodedKey := []byte(sjsonReplacer.Replace(encodedKeyString)) + pairs = append(pairs, mapPair{key: encodedKey, value: iter.Value()}) + } + + // Ensure deterministic output + sort.Slice(pairs, func(i, j int) bool { + return bytes.Compare(pairs[i].key, pairs[j].key) < 0 + }) + + elementEncoder := e.typeEncoder(v.Type().Elem()) + for _, p := range pairs { + encodedValue, err := elementEncoder(p.value) + if err != nil { + return nil, err + } + if len(encodedValue) == 0 { + continue + } + json, err = sjson.SetRawBytes(json, string(p.key), 
encodedValue) + if err != nil { + return nil, err + } + } + + return json, nil +} + +func (e *encoder) newMapEncoder(_ reflect.Type) encoderFunc { + return func(value reflect.Value) ([]byte, error) { + json := []byte("{}") + var err error + json, err = e.encodeMapEntries(json, value) + if err != nil { + return nil, err + } + return json, nil + } +} + +// If we want to set a literal key value into JSON using sjson, we need to make sure it doesn't have +// special characters that sjson interprets as a path. +var sjsonReplacer *strings.Replacer = strings.NewReplacer(".", "\\.", ":", "\\:", "*", "\\*") diff --git a/internal/apijson/enum.go b/internal/apijson/enum.go new file mode 100644 index 0000000..18b218a --- /dev/null +++ b/internal/apijson/enum.go @@ -0,0 +1,145 @@ +package apijson + +import ( + "fmt" + "reflect" + "slices" + "sync" + + "github.com/tidwall/gjson" +) + +/********************/ +/* Validating Enums */ +/********************/ + +type validationEntry struct { + field reflect.StructField + required bool + legalValues struct { + strings []string + // 1 represents true, 0 represents false, -1 represents either + bools int + ints []int64 + } +} + +type validatorFunc func(reflect.Value) exactness + +var validators sync.Map +var validationRegistry = map[reflect.Type][]validationEntry{} + +func RegisterFieldValidator[T any, V string | bool | int](fieldName string, values ...V) { + var t T + parentType := reflect.TypeOf(t) + + if _, ok := validationRegistry[parentType]; !ok { + validationRegistry[parentType] = []validationEntry{} + } + + // The following checks run at initialization time, + // it is impossible for them to panic if any tests pass. 
+ if parentType.Kind() != reflect.Struct { + panic(fmt.Sprintf("apijson: cannot initialize validator for non-struct %s", parentType.String())) + } + + var field reflect.StructField + found := false + for i := 0; i < parentType.NumField(); i++ { + ptag, ok := parseJSONStructTag(parentType.Field(i)) + if ok && ptag.name == fieldName { + field = parentType.Field(i) + found = true + break + } + } + + if !found { + panic(fmt.Sprintf("apijson: cannot find field %s in struct %s", fieldName, parentType.String())) + } + + newEntry := validationEntry{field: field} + newEntry.legalValues.bools = -1 // default to either + + switch values := any(values).(type) { + case []string: + newEntry.legalValues.strings = values + case []int: + newEntry.legalValues.ints = make([]int64, len(values)) + for i, value := range values { + newEntry.legalValues.ints[i] = int64(value) + } + case []bool: + for i, value := range values { + var next int + if value { + next = 1 + } + if i > 0 && newEntry.legalValues.bools != next { + newEntry.legalValues.bools = -1 // accept either + break + } + newEntry.legalValues.bools = next + } + } + + // Store the information necessary to create a validator, so that we can use it + // lazily create the validator function when did. 
+ validationRegistry[parentType] = append(validationRegistry[parentType], newEntry) +} + +func (state *decoderState) validateString(v reflect.Value) { + if state.validator == nil { + return + } + if !slices.Contains(state.validator.legalValues.strings, v.String()) { + state.exactness = loose + } +} + +func (state *decoderState) validateInt(v reflect.Value) { + if state.validator == nil { + return + } + if !slices.Contains(state.validator.legalValues.ints, v.Int()) { + state.exactness = loose + } +} + +func (state *decoderState) validateBool(v reflect.Value) { + if state.validator == nil { + return + } + b := v.Bool() + if state.validator.legalValues.bools == 1 && b == false { + state.exactness = loose + } else if state.validator.legalValues.bools == 0 && b == true { + state.exactness = loose + } +} + +func (state *decoderState) validateOptKind(node gjson.Result, t reflect.Type) { + switch node.Type { + case gjson.JSON: + state.exactness = loose + case gjson.Null: + return + case gjson.False, gjson.True: + if t.Kind() != reflect.Bool { + state.exactness = loose + } + case gjson.Number: + switch t.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + return + default: + state.exactness = loose + } + case gjson.String: + if t.Kind() != reflect.String { + state.exactness = loose + } + } +} diff --git a/internal/apijson/enum_test.go b/internal/apijson/enum_test.go new file mode 100644 index 0000000..a2aeed4 --- /dev/null +++ b/internal/apijson/enum_test.go @@ -0,0 +1,87 @@ +package apijson + +import ( + "reflect" + "testing" +) + +type EnumStruct struct { + NormalString string `json:"normal_string"` + StringEnum string `json:"string_enum"` + NamedEnum NamedEnumType `json:"named_enum"` + + IntEnum int `json:"int_enum"` + BoolEnum bool `json:"bool_enum"` + + WeirdBoolEnum bool `json:"weird_bool_enum"` +} + +func (o 
*EnumStruct) UnmarshalJSON(data []byte) error { + return UnmarshalRoot(data, o) +} + +func init() { + RegisterFieldValidator[EnumStruct]("string_enum", "one", "two", "three") + RegisterFieldValidator[EnumStruct]("int_enum", 200, 404) + RegisterFieldValidator[EnumStruct]("bool_enum", false) + RegisterFieldValidator[EnumStruct]("weird_bool_enum", true, false) +} + +type NamedEnumType string + +const ( + NamedEnumOne NamedEnumType = "one" + NamedEnumTwo NamedEnumType = "two" + NamedEnumThree NamedEnumType = "three" +) + +func (e NamedEnumType) IsKnown() bool { + return e == NamedEnumOne || e == NamedEnumTwo || e == NamedEnumThree +} + +func TestEnumStructStringValidator(t *testing.T) { + cases := map[string]struct { + exactness + EnumStruct + }{ + `{"string_enum":"one"}`: {exact, EnumStruct{StringEnum: "one"}}, + `{"string_enum":"two"}`: {exact, EnumStruct{StringEnum: "two"}}, + `{"string_enum":"three"}`: {exact, EnumStruct{StringEnum: "three"}}, + `{"string_enum":"none"}`: {loose, EnumStruct{StringEnum: "none"}}, + `{"int_enum":200}`: {exact, EnumStruct{IntEnum: 200}}, + `{"int_enum":404}`: {exact, EnumStruct{IntEnum: 404}}, + `{"int_enum":500}`: {loose, EnumStruct{IntEnum: 500}}, + `{"bool_enum":false}`: {exact, EnumStruct{BoolEnum: false}}, + `{"bool_enum":true}`: {loose, EnumStruct{BoolEnum: true}}, + `{"weird_bool_enum":true}`: {exact, EnumStruct{WeirdBoolEnum: true}}, + `{"weird_bool_enum":false}`: {exact, EnumStruct{WeirdBoolEnum: false}}, + + `{"named_enum":"one"}`: {exact, EnumStruct{NamedEnum: NamedEnumOne}}, + `{"named_enum":"none"}`: {loose, EnumStruct{NamedEnum: "none"}}, + + `{"string_enum":"one","named_enum":"one"}`: {exact, EnumStruct{NamedEnum: "one", StringEnum: "one"}}, + `{"string_enum":"four","named_enum":"one"}`: { + loose, + EnumStruct{NamedEnum: "one", StringEnum: "four"}, + }, + `{"string_enum":"one","named_enum":"four"}`: { + loose, EnumStruct{NamedEnum: "four", StringEnum: "one"}, + }, + `{"wrong_key":"one"}`: {extras, EnumStruct{StringEnum: 
""}}, + } + + for raw, expected := range cases { + var dst EnumStruct + + dec := decoderBuilder{root: true} + exactness, _ := dec.unmarshalWithExactness([]byte(raw), &dst) + + if !reflect.DeepEqual(dst, expected.EnumStruct) { + t.Fatalf("failed equality check %#v", dst) + } + + if exactness != expected.exactness { + t.Fatalf("exactness got %d expected %d %s", exactness, expected.exactness, raw) + } + } +} diff --git a/internal/apijson/field.go b/internal/apijson/field.go new file mode 100644 index 0000000..854d6dd --- /dev/null +++ b/internal/apijson/field.go @@ -0,0 +1,23 @@ +package apijson + +type status uint8 + +const ( + missing status = iota + null + invalid + valid +) + +type Field struct { + raw string + status status +} + +// Returns true if the field is explicitly `null` _or_ if it is not present at all (ie, missing). +// To check if the field's key is present in the JSON with an explicit null value, +// you must check `f.IsNull() && !f.IsMissing()`. +func (j Field) IsNull() bool { return j.status <= null } +func (j Field) IsMissing() bool { return j.status == missing } +func (j Field) IsInvalid() bool { return j.status == invalid } +func (j Field) Raw() string { return j.raw } diff --git a/internal/apijson/json_test.go b/internal/apijson/json_test.go new file mode 100644 index 0000000..02904d2 --- /dev/null +++ b/internal/apijson/json_test.go @@ -0,0 +1,616 @@ +package apijson + +import ( + "reflect" + "strings" + "testing" + "time" + + "github.com/tidwall/gjson" +) + +func P[T any](v T) *T { return &v } + +type Primitives struct { + A bool `json:"a"` + B int `json:"b"` + C uint `json:"c"` + D float64 `json:"d"` + E float32 `json:"e"` + F []int `json:"f"` +} + +type PrimitivePointers struct { + A *bool `json:"a"` + B *int `json:"b"` + C *uint `json:"c"` + D *float64 `json:"d"` + E *float32 `json:"e"` + F *[]int `json:"f"` +} + +type Slices struct { + Slice []Primitives `json:"slices"` +} + +type DateTime struct { + Date time.Time `json:"date" 
format:"date"` + DateTime time.Time `json:"date-time" format:"date-time"` +} + +type AdditionalProperties struct { + A bool `json:"a"` + ExtraFields map[string]any `json:"-,extras"` +} + +type TypedAdditionalProperties struct { + A bool `json:"a"` + ExtraFields map[string]int `json:"-,extras"` +} + +type EmbeddedStruct struct { + A bool `json:"a"` + B string `json:"b"` + + JSON EmbeddedStructJSON +} + +type EmbeddedStructJSON struct { + A Field + B Field + ExtraFields map[string]Field + raw string +} + +type EmbeddedStructs struct { + EmbeddedStruct + A *int `json:"a"` + ExtraFields map[string]any `json:"-,extras"` + + JSON EmbeddedStructsJSON +} + +type EmbeddedStructsJSON struct { + A Field + ExtraFields map[string]Field + raw string +} + +type Recursive struct { + Name string `json:"name"` + Child *Recursive `json:"child"` +} + +type JSONFieldStruct struct { + A bool `json:"a"` + B int64 `json:"b"` + C string `json:"c"` + D string `json:"d"` + ExtraFields map[string]int64 `json:",extras"` + JSON JSONFieldStructJSON `json:",metadata"` +} + +type JSONFieldStructJSON struct { + A Field + B Field + C Field + D Field + ExtraFields map[string]Field + raw string +} + +type UnknownStruct struct { + Unknown any `json:"unknown"` +} + +type UnionStruct struct { + Union Union `json:"union" format:"date"` +} + +type Union interface { + union() +} + +type Inline struct { + InlineField Primitives `json:",inline"` + JSON InlineJSON `json:",metadata"` +} + +type InlineArray struct { + InlineField []string `json:",inline"` + JSON InlineJSON `json:",metadata"` +} + +type InlineJSON struct { + InlineField Field + raw string +} + +type UnionInteger int64 + +func (UnionInteger) union() {} + +type UnionStructA struct { + Type string `json:"type"` + A string `json:"a"` + B string `json:"b"` +} + +func (UnionStructA) union() {} + +type UnionStructB struct { + Type string `json:"type"` + A string `json:"a"` +} + +func (UnionStructB) union() {} + +type UnionTime time.Time + +func 
(UnionTime) union() {} + +func init() { + RegisterUnion[Union]("type", + UnionVariant{ + TypeFilter: gjson.String, + Type: reflect.TypeOf(UnionTime{}), + }, + UnionVariant{ + TypeFilter: gjson.Number, + Type: reflect.TypeOf(UnionInteger(0)), + }, + UnionVariant{ + TypeFilter: gjson.JSON, + DiscriminatorValue: "typeA", + Type: reflect.TypeOf(UnionStructA{}), + }, + UnionVariant{ + TypeFilter: gjson.JSON, + DiscriminatorValue: "typeB", + Type: reflect.TypeOf(UnionStructB{}), + }, + ) +} + +type ComplexUnionStruct struct { + Union ComplexUnion `json:"union"` +} + +type ComplexUnion interface { + complexUnion() +} + +type ComplexUnionA struct { + Boo string `json:"boo"` + Foo bool `json:"foo"` +} + +func (ComplexUnionA) complexUnion() {} + +type ComplexUnionB struct { + Boo bool `json:"boo"` + Foo string `json:"foo"` +} + +func (ComplexUnionB) complexUnion() {} + +type ComplexUnionC struct { + Boo int64 `json:"boo"` +} + +func (ComplexUnionC) complexUnion() {} + +type ComplexUnionTypeA struct { + Baz int64 `json:"baz"` + Type TypeA `json:"type"` +} + +func (ComplexUnionTypeA) complexUnion() {} + +type TypeA string + +func (t TypeA) IsKnown() bool { + return t == "a" +} + +type ComplexUnionTypeB struct { + Baz int64 `json:"baz"` + Type TypeB `json:"type"` +} + +type TypeB string + +func (t TypeB) IsKnown() bool { + return t == "b" +} + +type UnmarshalStruct struct { + Foo string `json:"foo"` + prop bool `json:"-"` +} + +func (r *UnmarshalStruct) UnmarshalJSON(json []byte) error { + r.prop = true + return UnmarshalRoot(json, r) +} + +func (ComplexUnionTypeB) complexUnion() {} + +func init() { + RegisterUnion[ComplexUnion]("", + UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(ComplexUnionA{}), + }, + UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(ComplexUnionB{}), + }, + UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(ComplexUnionC{}), + }, + UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(ComplexUnionTypeA{}), 
+ }, + UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(ComplexUnionTypeB{}), + }, + ) +} + +type MarshallingUnionStruct struct { + Union MarshallingUnion +} + +func (r *MarshallingUnionStruct) UnmarshalJSON(data []byte) (err error) { + *r = MarshallingUnionStruct{} + err = UnmarshalRoot(data, &r.Union) + return +} + +func (r MarshallingUnionStruct) MarshalJSON() (data []byte, err error) { + return MarshalRoot(r.Union) +} + +type MarshallingUnion interface { + marshallingUnion() +} + +type MarshallingUnionA struct { + Boo string `json:"boo"` +} + +func (MarshallingUnionA) marshallingUnion() {} + +func (r *MarshallingUnionA) UnmarshalJSON(data []byte) (err error) { + return UnmarshalRoot(data, r) +} + +type MarshallingUnionB struct { + Foo string `json:"foo"` +} + +func (MarshallingUnionB) marshallingUnion() {} + +func (r *MarshallingUnionB) UnmarshalJSON(data []byte) (err error) { + return UnmarshalRoot(data, r) +} + +func init() { + RegisterUnion[MarshallingUnion]( + "", + UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(MarshallingUnionA{}), + }, + UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(MarshallingUnionB{}), + }, + ) +} + +var tests = map[string]struct { + buf string + val any +}{ + "true": {"true", true}, + "false": {"false", false}, + "int": {"1", 1}, + "int_bigger": {"12324", 12324}, + "int_string_coerce": {`"65"`, 65}, + "int_boolean_coerce": {"true", 1}, + "int64": {"1", int64(1)}, + "int64_huge": {"123456789123456789", int64(123456789123456789)}, + "uint": {"1", uint(1)}, + "uint_bigger": {"12324", uint(12324)}, + "uint_coerce": {`"65"`, uint(65)}, + "float_1.54": {"1.54", float32(1.54)}, + "float_1.89": {"1.89", float64(1.89)}, + "string": {`"str"`, "str"}, + "string_int_coerce": {`12`, "12"}, + "array_string": {`["foo","bar"]`, []string{"foo", "bar"}}, + "array_int": {`[1,2]`, []int{1, 2}}, + "array_int_coerce": {`["1",2]`, []int{1, 2}}, + + "ptr_true": {"true", P(true)}, + "ptr_false": {"false", 
P(false)}, + "ptr_int": {"1", P(1)}, + "ptr_int_bigger": {"12324", P(12324)}, + "ptr_int_string_coerce": {`"65"`, P(65)}, + "ptr_int_boolean_coerce": {"true", P(1)}, + "ptr_int64": {"1", P(int64(1))}, + "ptr_int64_huge": {"123456789123456789", P(int64(123456789123456789))}, + "ptr_uint": {"1", P(uint(1))}, + "ptr_uint_bigger": {"12324", P(uint(12324))}, + "ptr_uint_coerce": {`"65"`, P(uint(65))}, + "ptr_float_1.54": {"1.54", P(float32(1.54))}, + "ptr_float_1.89": {"1.89", P(float64(1.89))}, + + "date_time": {`"2007-03-01T13:00:00Z"`, time.Date(2007, time.March, 1, 13, 0, 0, 0, time.UTC)}, + "date_time_nano_coerce": {`"2007-03-01T13:03:05.123456789Z"`, time.Date(2007, time.March, 1, 13, 3, 5, 123456789, time.UTC)}, + + "date_time_missing_t_coerce": {`"2007-03-01 13:03:05Z"`, time.Date(2007, time.March, 1, 13, 3, 5, 0, time.UTC)}, + "date_time_missing_timezone_coerce": {`"2007-03-01T13:03:05"`, time.Date(2007, time.March, 1, 13, 3, 5, 0, time.UTC)}, + // note: using -1200 to minimize probability of conflicting with the local timezone of the test runner + // see https://en.wikipedia.org/wiki/UTC%E2%88%9212:00 + "date_time_missing_timezone_colon_coerce": {`"2007-03-01T13:03:05-1200"`, time.Date(2007, time.March, 1, 13, 3, 5, 0, time.FixedZone("", -12*60*60))}, + "date_time_nano_missing_t_coerce": {`"2007-03-01 13:03:05.123456789Z"`, time.Date(2007, time.March, 1, 13, 3, 5, 123456789, time.UTC)}, + + "map_string": {`{"foo":"bar"}`, map[string]string{"foo": "bar"}}, + "map_string_with_sjson_path_chars": {`{":a.b.c*:d*-1e.f":"bar"}`, map[string]string{":a.b.c*:d*-1e.f": "bar"}}, + "map_interface": {`{"a":1,"b":"str","c":false}`, map[string]any{"a": float64(1), "b": "str", "c": false}}, + + "primitive_struct": { + `{"a":false,"b":237628372683,"c":654,"d":9999.43,"e":43.76,"f":[1,2,3,4]}`, + Primitives{A: false, B: 237628372683, C: uint(654), D: 9999.43, E: 43.76, F: []int{1, 2, 3, 4}}, + }, + + "slices": { + 
`{"slices":[{"a":false,"b":237628372683,"c":654,"d":9999.43,"e":43.76,"f":[1,2,3,4]}]}`, + Slices{ + Slice: []Primitives{{A: false, B: 237628372683, C: uint(654), D: 9999.43, E: 43.76, F: []int{1, 2, 3, 4}}}, + }, + }, + + "primitive_pointer_struct": { + `{"a":false,"b":237628372683,"c":654,"d":9999.43,"e":43.76,"f":[1,2,3,4,5]}`, + PrimitivePointers{ + A: P(false), + B: P(237628372683), + C: P(uint(654)), + D: P(9999.43), + E: P(float32(43.76)), + F: &[]int{1, 2, 3, 4, 5}, + }, + }, + + "datetime_struct": { + `{"date":"2006-01-02","date-time":"2006-01-02T15:04:05Z"}`, + DateTime{ + Date: time.Date(2006, time.January, 2, 0, 0, 0, 0, time.UTC), + DateTime: time.Date(2006, time.January, 2, 15, 4, 5, 0, time.UTC), + }, + }, + + "additional_properties": { + `{"a":true,"bar":"value","foo":true}`, + AdditionalProperties{ + A: true, + ExtraFields: map[string]any{ + "bar": "value", + "foo": true, + }, + }, + }, + + "embedded_struct": { + `{"a":1,"b":"bar"}`, + EmbeddedStructs{ + EmbeddedStruct: EmbeddedStruct{ + A: true, + B: "bar", + JSON: EmbeddedStructJSON{ + A: Field{raw: `1`, status: valid}, + B: Field{raw: `"bar"`, status: valid}, + raw: `{"a":1,"b":"bar"}`, + }, + }, + A: P(1), + ExtraFields: map[string]any{"b": "bar"}, + JSON: EmbeddedStructsJSON{ + A: Field{raw: `1`, status: valid}, + ExtraFields: map[string]Field{ + "b": {raw: `"bar"`, status: valid}, + }, + raw: `{"a":1,"b":"bar"}`, + }, + }, + }, + + "recursive_struct": { + `{"child":{"name":"Alex"},"name":"Robert"}`, + Recursive{Name: "Robert", Child: &Recursive{Name: "Alex"}}, + }, + + "metadata_coerce": { + `{"a":"12","b":"12","c":null,"extra_typed":12,"extra_untyped":{"foo":"bar"}}`, + JSONFieldStruct{ + A: false, + B: 12, + C: "", + JSON: JSONFieldStructJSON{ + raw: `{"a":"12","b":"12","c":null,"extra_typed":12,"extra_untyped":{"foo":"bar"}}`, + A: Field{raw: `"12"`, status: invalid}, + B: Field{raw: `"12"`, status: valid}, + C: Field{raw: "null", status: null}, + D: Field{raw: "", status: missing}, + 
ExtraFields: map[string]Field{ + "extra_typed": { + raw: "12", + status: valid, + }, + "extra_untyped": { + raw: `{"foo":"bar"}`, + status: invalid, + }, + }, + }, + ExtraFields: map[string]int64{ + "extra_typed": 12, + "extra_untyped": 0, + }, + }, + }, + + "unknown_struct_number": { + `{"unknown":12}`, + UnknownStruct{ + Unknown: 12., + }, + }, + + "unknown_struct_map": { + `{"unknown":{"foo":"bar"}}`, + UnknownStruct{ + Unknown: map[string]any{ + "foo": "bar", + }, + }, + }, + + "union_integer": { + `{"union":12}`, + UnionStruct{ + Union: UnionInteger(12), + }, + }, + + "union_struct_discriminated_a": { + `{"union":{"a":"foo","b":"bar","type":"typeA"}}`, + UnionStruct{ + Union: UnionStructA{ + Type: "typeA", + A: "foo", + B: "bar", + }, + }, + }, + + "union_struct_discriminated_b": { + `{"union":{"a":"foo","type":"typeB"}}`, + UnionStruct{ + Union: UnionStructB{ + Type: "typeB", + A: "foo", + }, + }, + }, + + "union_struct_time": { + `{"union":"2010-05-23"}`, + UnionStruct{ + Union: UnionTime(time.Date(2010, 05, 23, 0, 0, 0, 0, time.UTC)), + }, + }, + + "complex_union_a": { + `{"union":{"boo":"12","foo":true}}`, + ComplexUnionStruct{Union: ComplexUnionA{Boo: "12", Foo: true}}, + }, + + "complex_union_b": { + `{"union":{"boo":true,"foo":"12"}}`, + ComplexUnionStruct{Union: ComplexUnionB{Boo: true, Foo: "12"}}, + }, + + "complex_union_c": { + `{"union":{"boo":12}}`, + ComplexUnionStruct{Union: ComplexUnionC{Boo: 12}}, + }, + + "complex_union_type_a": { + `{"union":{"baz":12,"type":"a"}}`, + ComplexUnionStruct{Union: ComplexUnionTypeA{Baz: 12, Type: TypeA("a")}}, + }, + + "complex_union_type_b": { + `{"union":{"baz":12,"type":"b"}}`, + ComplexUnionStruct{Union: ComplexUnionTypeB{Baz: 12, Type: TypeB("b")}}, + }, + + "marshalling_union_a": { + `{"boo":"hello"}`, + MarshallingUnionStruct{Union: MarshallingUnionA{Boo: "hello"}}, + }, + "marshalling_union_b": { + `{"foo":"hi"}`, + MarshallingUnionStruct{Union: MarshallingUnionB{Foo: "hi"}}, + }, + + "unmarshal": { + 
`{"foo":"hello"}`, + &UnmarshalStruct{Foo: "hello", prop: true}, + }, + + "array_of_unmarshal": { + `[{"foo":"hello"}]`, + []UnmarshalStruct{{Foo: "hello", prop: true}}, + }, + + "inline_coerce": { + `{"a":false,"b":237628372683,"c":654,"d":9999.43,"e":43.76,"f":[1,2,3,4]}`, + Inline{ + InlineField: Primitives{A: false, B: 237628372683, C: 0x28e, D: 9999.43, E: 43.76, F: []int{1, 2, 3, 4}}, + JSON: InlineJSON{ + InlineField: Field{raw: "{\"a\":false,\"b\":237628372683,\"c\":654,\"d\":9999.43,\"e\":43.76,\"f\":[1,2,3,4]}", status: 3}, + raw: "{\"a\":false,\"b\":237628372683,\"c\":654,\"d\":9999.43,\"e\":43.76,\"f\":[1,2,3,4]}", + }, + }, + }, + + "inline_array_coerce": { + `["Hello","foo","bar"]`, + InlineArray{ + InlineField: []string{"Hello", "foo", "bar"}, + JSON: InlineJSON{ + InlineField: Field{raw: `["Hello","foo","bar"]`, status: 3}, + raw: `["Hello","foo","bar"]`, + }, + }, + }, +} + +func TestDecode(t *testing.T) { + for name, test := range tests { + t.Run(name, func(t *testing.T) { + result := reflect.New(reflect.TypeOf(test.val)) + if err := Unmarshal([]byte(test.buf), result.Interface()); err != nil { + t.Fatalf("deserialization of %v failed with error %v", result, err) + } + if !reflect.DeepEqual(result.Elem().Interface(), test.val) { + t.Fatalf("expected '%s' to deserialize to \n%#v\nbut got\n%#v", test.buf, test.val, result.Elem().Interface()) + } + }) + } +} + +func TestEncode(t *testing.T) { + for name, test := range tests { + if strings.HasSuffix(name, "_coerce") { + continue + } + t.Run(name, func(t *testing.T) { + raw, err := Marshal(test.val) + if err != nil { + t.Fatalf("serialization of %v failed with error %v", test.val, err) + } + if string(raw) != test.buf { + t.Fatalf("expected %+#v to serialize to %s but got %s", test.val, test.buf, string(raw)) + } + }) + } +} diff --git a/internal/apijson/port.go b/internal/apijson/port.go new file mode 100644 index 0000000..b40013c --- /dev/null +++ b/internal/apijson/port.go @@ -0,0 +1,120 @@ +package 
apijson + +import ( + "fmt" + "reflect" +) + +// Port copies over values from one struct to another struct. +func Port(from any, to any) error { + toVal := reflect.ValueOf(to) + fromVal := reflect.ValueOf(from) + + if toVal.Kind() != reflect.Ptr || toVal.IsNil() { + return fmt.Errorf("destination must be a non-nil pointer") + } + + for toVal.Kind() == reflect.Ptr { + toVal = toVal.Elem() + } + toType := toVal.Type() + + for fromVal.Kind() == reflect.Ptr { + fromVal = fromVal.Elem() + } + fromType := fromVal.Type() + + if toType.Kind() != reflect.Struct { + return fmt.Errorf("destination must be a non-nil pointer to a struct (%v %v)", toType, toType.Kind()) + } + + values := map[string]reflect.Value{} + fields := map[string]reflect.Value{} + + fromJSON := fromVal.FieldByName("JSON") + toJSON := toVal.FieldByName("JSON") + + // Iterate through the fields of v and load all the "normal" fields in the struct to the map of + // string to reflect.Value, as well as their raw .JSON.Foo counterpart indicated by j. + var getFields func(t reflect.Type, v reflect.Value) + getFields = func(t reflect.Type, v reflect.Value) { + j := v.FieldByName("JSON") + + // Recurse into anonymous fields first, since the fields on the object should win over the fields in the + // embedded object. + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if field.Anonymous { + getFields(field.Type, v.Field(i)) + continue + } + } + + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + ptag, ok := parseJSONStructTag(field) + if !ok || ptag.name == "-" || ptag.name == "" { + continue + } + values[ptag.name] = v.Field(i) + if j.IsValid() { + fields[ptag.name] = j.FieldByName(field.Name) + } + } + } + getFields(fromType, fromVal) + + // Use the values from the previous step to populate the 'to' struct. 
+ for i := 0; i < toType.NumField(); i++ { + field := toType.Field(i) + ptag, ok := parseJSONStructTag(field) + if !ok { + continue + } + if ptag.name == "-" { + continue + } + if value, ok := values[ptag.name]; ok { + delete(values, ptag.name) + if field.Type.Kind() == reflect.Interface { + toVal.Field(i).Set(value) + } else { + switch value.Kind() { + case reflect.String: + toVal.Field(i).SetString(value.String()) + case reflect.Bool: + toVal.Field(i).SetBool(value.Bool()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + toVal.Field(i).SetInt(value.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + toVal.Field(i).SetUint(value.Uint()) + case reflect.Float32, reflect.Float64: + toVal.Field(i).SetFloat(value.Float()) + default: + toVal.Field(i).Set(value) + } + } + } + + if fromJSONField, ok := fields[ptag.name]; ok { + if toJSONField := toJSON.FieldByName(field.Name); toJSONField.IsValid() { + toJSONField.Set(fromJSONField) + } + } + } + + // Finally, copy over the .JSON.raw and .JSON.ExtraFields + if toJSON.IsValid() { + if raw := toJSON.FieldByName("raw"); raw.IsValid() { + setUnexportedField(raw, fromJSON.Interface().(interface{ RawJSON() string }).RawJSON()) + } + + if toExtraFields := toJSON.FieldByName("ExtraFields"); toExtraFields.IsValid() { + if fromExtraFields := fromJSON.FieldByName("ExtraFields"); fromExtraFields.IsValid() { + setUnexportedField(toExtraFields, fromExtraFields.Interface()) + } + } + } + + return nil +} diff --git a/internal/apijson/port_test.go b/internal/apijson/port_test.go new file mode 100644 index 0000000..bb01f1a --- /dev/null +++ b/internal/apijson/port_test.go @@ -0,0 +1,257 @@ +package apijson + +import ( + "reflect" + "testing" +) + +type Metadata struct { + CreatedAt string `json:"created_at"` +} + +// Card is the "combined" type of CardVisa and CardMastercard +type Card struct { + Processor CardProcessor `json:"processor"` + Data any `json:"data"` + 
IsFoo bool `json:"is_foo"` + IsBar bool `json:"is_bar"` + Metadata Metadata `json:"metadata"` + Value any `json:"value"` + + JSON cardJSON +} + +type cardJSON struct { + Processor Field + Data Field + IsFoo Field + IsBar Field + Metadata Field + Value Field + ExtraFields map[string]Field + raw string +} + +func (r cardJSON) RawJSON() string { return r.raw } + +type CardProcessor string + +// CardVisa +type CardVisa struct { + Processor CardVisaProcessor `json:"processor"` + Data CardVisaData `json:"data"` + IsFoo bool `json:"is_foo"` + Metadata Metadata `json:"metadata"` + Value string `json:"value"` + + JSON cardVisaJSON +} + +type cardVisaJSON struct { + Processor Field + Data Field + IsFoo Field + Metadata Field + Value Field + ExtraFields map[string]Field + raw string +} + +func (r cardVisaJSON) RawJSON() string { return r.raw } + +type CardVisaProcessor string + +type CardVisaData struct { + Foo string `json:"foo"` +} + +// CardMastercard +type CardMastercard struct { + Processor CardMastercardProcessor `json:"processor"` + Data CardMastercardData `json:"data"` + IsBar bool `json:"is_bar"` + Metadata Metadata `json:"metadata"` + Value bool `json:"value"` + + JSON cardMastercardJSON +} + +type cardMastercardJSON struct { + Processor Field + Data Field + IsBar Field + Metadata Field + Value Field + ExtraFields map[string]Field + raw string +} + +func (r cardMastercardJSON) RawJSON() string { return r.raw } + +type CardMastercardProcessor string + +type CardMastercardData struct { + Bar int64 `json:"bar"` +} + +type CommonFields struct { + Metadata Metadata `json:"metadata"` + Value string `json:"value"` + + JSON commonFieldsJSON +} + +type commonFieldsJSON struct { + Metadata Field + Value Field + ExtraFields map[string]Field + raw string +} + +type CardEmbedded struct { + CommonFields + Processor CardVisaProcessor `json:"processor"` + Data CardVisaData `json:"data"` + IsFoo bool `json:"is_foo"` + + JSON cardEmbeddedJSON +} + +type cardEmbeddedJSON struct { + 
Processor Field + Data Field + IsFoo Field + ExtraFields map[string]Field + raw string +} + +func (r cardEmbeddedJSON) RawJSON() string { return r.raw } + +var portTests = map[string]struct { + from any + to any +}{ + "visa to card": { + CardVisa{ + Processor: "visa", + IsFoo: true, + Data: CardVisaData{ + Foo: "foo", + }, + Metadata: Metadata{ + CreatedAt: "Mar 29 2024", + }, + Value: "value", + JSON: cardVisaJSON{ + raw: `{"processor":"visa","is_foo":true,"data":{"foo":"foo"}}`, + Processor: Field{raw: `"visa"`, status: valid}, + IsFoo: Field{raw: `true`, status: valid}, + Data: Field{raw: `{"foo":"foo"}`, status: valid}, + Value: Field{raw: `"value"`, status: valid}, + ExtraFields: map[string]Field{"extra": {raw: `"yo"`, status: valid}}, + }, + }, + Card{ + Processor: "visa", + IsFoo: true, + IsBar: false, + Data: CardVisaData{ + Foo: "foo", + }, + Metadata: Metadata{ + CreatedAt: "Mar 29 2024", + }, + Value: "value", + JSON: cardJSON{ + raw: `{"processor":"visa","is_foo":true,"data":{"foo":"foo"}}`, + Processor: Field{raw: `"visa"`, status: valid}, + IsFoo: Field{raw: `true`, status: valid}, + Data: Field{raw: `{"foo":"foo"}`, status: valid}, + Value: Field{raw: `"value"`, status: valid}, + ExtraFields: map[string]Field{"extra": {raw: `"yo"`, status: valid}}, + }, + }, + }, + "mastercard to card": { + CardMastercard{ + Processor: "mastercard", + IsBar: true, + Data: CardMastercardData{ + Bar: 13, + }, + Value: false, + }, + Card{ + Processor: "mastercard", + IsFoo: false, + IsBar: true, + Data: CardMastercardData{ + Bar: 13, + }, + Value: false, + }, + }, + "embedded to card": { + CardEmbedded{ + CommonFields: CommonFields{ + Metadata: Metadata{ + CreatedAt: "Mar 29 2024", + }, + Value: "embedded_value", + JSON: commonFieldsJSON{ + Metadata: Field{raw: `{"created_at":"Mar 29 2024"}`, status: valid}, + Value: Field{raw: `"embedded_value"`, status: valid}, + raw: `should not matter`, + }, + }, + Processor: "visa", + IsFoo: true, + Data: CardVisaData{ + Foo: 
"embedded_foo", + }, + JSON: cardEmbeddedJSON{ + raw: `{"processor":"visa","is_foo":true,"data":{"foo":"embedded_foo"},"metadata":{"created_at":"Mar 29 2024"},"value":"embedded_value"}`, + Processor: Field{raw: `"visa"`, status: valid}, + IsFoo: Field{raw: `true`, status: valid}, + Data: Field{raw: `{"foo":"embedded_foo"}`, status: valid}, + }, + }, + Card{ + Processor: "visa", + IsFoo: true, + IsBar: false, + Data: CardVisaData{ + Foo: "embedded_foo", + }, + Metadata: Metadata{ + CreatedAt: "Mar 29 2024", + }, + Value: "embedded_value", + JSON: cardJSON{ + raw: `{"processor":"visa","is_foo":true,"data":{"foo":"embedded_foo"},"metadata":{"created_at":"Mar 29 2024"},"value":"embedded_value"}`, + Processor: Field{raw: `"visa"`, status: 0x3}, + IsFoo: Field{raw: "true", status: 0x3}, + Data: Field{raw: `{"foo":"embedded_foo"}`, status: 0x3}, + Metadata: Field{raw: `{"created_at":"Mar 29 2024"}`, status: 0x3}, + Value: Field{raw: `"embedded_value"`, status: 0x3}, + }, + }, + }, +} + +func TestPort(t *testing.T) { + for name, test := range portTests { + t.Run(name, func(t *testing.T) { + toVal := reflect.New(reflect.TypeOf(test.to)) + + err := Port(test.from, toVal.Interface()) + if err != nil { + t.Fatalf("port of %v failed with error %v", test.from, err) + } + + if !reflect.DeepEqual(toVal.Elem().Interface(), test.to) { + t.Fatalf("expected:\n%+#v\n\nto port to:\n%+#v\n\nbut got:\n%+#v", test.from, test.to, toVal.Elem().Interface()) + } + }) + } +} diff --git a/internal/apijson/registry.go b/internal/apijson/registry.go new file mode 100644 index 0000000..2a24982 --- /dev/null +++ b/internal/apijson/registry.go @@ -0,0 +1,51 @@ +package apijson + +import ( + "reflect" + + "github.com/tidwall/gjson" +) + +type UnionVariant struct { + TypeFilter gjson.Type + DiscriminatorValue any + Type reflect.Type +} + +var unionRegistry = map[reflect.Type]unionEntry{} +var unionVariants = map[reflect.Type]any{} + +type unionEntry struct { + discriminatorKey string + variants 
[]UnionVariant +} + +func Discriminator[T any](value any) UnionVariant { + var zero T + return UnionVariant{ + TypeFilter: gjson.JSON, + DiscriminatorValue: value, + Type: reflect.TypeOf(zero), + } +} + +func RegisterUnion[T any](discriminator string, variants ...UnionVariant) { + typ := reflect.TypeOf((*T)(nil)).Elem() + unionRegistry[typ] = unionEntry{ + discriminatorKey: discriminator, + variants: variants, + } + for _, variant := range variants { + unionVariants[variant.Type] = typ + } +} + +// Useful to wrap a union type to force it to use [apijson.UnmarshalJSON] since you cannot define an +// UnmarshalJSON function on the interface itself. +type UnionUnmarshaler[T any] struct { + Value T +} + +func (c *UnionUnmarshaler[T]) UnmarshalJSON(buf []byte) error { + return UnmarshalRoot(buf, &c.Value) +} diff --git a/internal/apijson/subfield.go b/internal/apijson/subfield.go new file mode 100644 index 0000000..65ca2de --- /dev/null +++ b/internal/apijson/subfield.go @@ -0,0 +1,67 @@ +package apijson + +import ( + "github.com/ScrapeGraphAI/scrapegraph-sdk/packages/respjson" + "reflect" +) + +func getSubField(root reflect.Value, index []int, name string) reflect.Value { + strct := root.FieldByIndex(index[:len(index)-1]) + if !strct.IsValid() { + panic("couldn't find encapsulating struct for field " + name) + } + meta := strct.FieldByName("JSON") + if !meta.IsValid() { + return reflect.Value{} + } + field := meta.FieldByName(name) + if !field.IsValid() { + return reflect.Value{} + } + return field +} + +func setMetadataSubField(root reflect.Value, index []int, name string, meta Field) { + target := getSubField(root, index, name) + if !target.IsValid() { + return + } + + if target.Type() == reflect.TypeOf(meta) { + target.Set(reflect.ValueOf(meta)) + } else if respMeta := meta.toRespField(); target.Type() == reflect.TypeOf(respMeta) { + target.Set(reflect.ValueOf(respMeta)) + } +} + +func setMetadataExtraFields(root reflect.Value, index []int, name string, metaExtras 
map[string]Field) { + target := getSubField(root, index, name) + if !target.IsValid() { + return + } + + if target.Type() == reflect.TypeOf(metaExtras) { + target.Set(reflect.ValueOf(metaExtras)) + return + } + + newMap := make(map[string]respjson.Field, len(metaExtras)) + if target.Type() == reflect.TypeOf(newMap) { + for k, v := range metaExtras { + newMap[k] = v.toRespField() + } + target.Set(reflect.ValueOf(newMap)) + } +} + +func (f Field) toRespField() respjson.Field { + if f.IsMissing() { + return respjson.Field{} + } else if f.IsNull() { + return respjson.NewField("null") + } else if f.IsInvalid() { + return respjson.NewInvalidField(f.raw) + } else { + return respjson.NewField(f.raw) + } +} diff --git a/internal/apijson/tag.go b/internal/apijson/tag.go new file mode 100644 index 0000000..812fb3c --- /dev/null +++ b/internal/apijson/tag.go @@ -0,0 +1,47 @@ +package apijson + +import ( + "reflect" + "strings" +) + +const jsonStructTag = "json" +const formatStructTag = "format" + +type parsedStructTag struct { + name string + required bool + extras bool + metadata bool + inline bool +} + +func parseJSONStructTag(field reflect.StructField) (tag parsedStructTag, ok bool) { + raw, ok := field.Tag.Lookup(jsonStructTag) + if !ok { + return + } + parts := strings.Split(raw, ",") + if len(parts) == 0 { + return tag, false + } + tag.name = parts[0] + for _, part := range parts[1:] { + switch part { + case "required": + tag.required = true + case "extras": + tag.extras = true + case "metadata": + tag.metadata = true + case "inline": + tag.inline = true + } + } + return +} + +func parseFormatStructTag(field reflect.StructField) (format string, ok bool) { + format, ok = field.Tag.Lookup(formatStructTag) + return +} diff --git a/internal/apijson/union.go b/internal/apijson/union.go new file mode 100644 index 0000000..a766238 --- /dev/null +++ b/internal/apijson/union.go @@ -0,0 +1,202 @@ +package apijson + +import ( + "errors" + 
"github.com/ScrapeGraphAI/scrapegraph-sdk/packages/param" + "reflect" + + "github.com/tidwall/gjson" +) + +var apiUnionType = reflect.TypeOf(param.APIUnion{}) + +func isStructUnion(t reflect.Type) bool { + if t.Kind() != reflect.Struct { + return false + } + for i := 0; i < t.NumField(); i++ { + if t.Field(i).Type == apiUnionType && t.Field(i).Anonymous { + return true + } + } + return false +} + +func RegisterDiscriminatedUnion[T any](key string, mappings map[string]reflect.Type) { + var t T + entry := unionEntry{ + discriminatorKey: key, + variants: []UnionVariant{}, + } + for k, typ := range mappings { + entry.variants = append(entry.variants, UnionVariant{ + DiscriminatorValue: k, + Type: typ, + }) + } + unionRegistry[reflect.TypeOf(t)] = entry +} + +func (d *decoderBuilder) newStructUnionDecoder(t reflect.Type) decoderFunc { + type variantDecoder struct { + decoder decoderFunc + field reflect.StructField + discriminatorValue any + } + + variants := []variantDecoder{} + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + + if field.Anonymous && field.Type == apiUnionType { + continue + } + + decoder := d.typeDecoder(field.Type) + variants = append(variants, variantDecoder{ + decoder: decoder, + field: field, + }) + } + + unionEntry, discriminated := unionRegistry[t] + for _, unionVariant := range unionEntry.variants { + for i := 0; i < len(variants); i++ { + variant := &variants[i] + if variant.field.Type.Elem() == unionVariant.Type { + variant.discriminatorValue = unionVariant.DiscriminatorValue + break + } + } + } + + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + if discriminated && n.Type == gjson.JSON && len(unionEntry.discriminatorKey) != 0 { + discriminator := n.Get(unionEntry.discriminatorKey).Value() + for _, variant := range variants { + if discriminator == variant.discriminatorValue { + inner := v.FieldByIndex(variant.field.Index) + return variant.decoder(n, inner, state) + } + } + return errors.New("apijson: was 
not able to find discriminated union variant") + } + + // Set bestExactness to worse than loose + bestExactness := loose - 1 + bestVariant := -1 + for i, variant := range variants { + // Pointers are used to discern JSON object variants from value variants + if n.Type != gjson.JSON && variant.field.Type.Kind() == reflect.Ptr { + continue + } + + sub := decoderState{strict: state.strict, exactness: exact} + inner := v.FieldByIndex(variant.field.Index) + err := variant.decoder(n, inner, &sub) + if err != nil { + continue + } + if sub.exactness == exact { + bestExactness = exact + bestVariant = i + break + } + if sub.exactness > bestExactness { + bestExactness = sub.exactness + bestVariant = i + } + } + + if bestExactness < loose { + return errors.New("apijson: was not able to coerce type as union") + } + + if guardStrict(state, bestExactness != exact) { + return errors.New("apijson: was not able to coerce type as union strictly") + } + + for i := 0; i < len(variants); i++ { + if i == bestVariant { + continue + } + v.FieldByIndex(variants[i].field.Index).SetZero() + } + + return nil + } +} + +// newUnionDecoder returns a decoderFunc that deserializes into a union using an +// algorithm roughly similar to Pydantic's [smart algorithm]. +// +// Conceptually this is equivalent to choosing the best schema based on how 'exact' +// the deserialization is for each of the schemas. +// +// If there is a tie in the level of exactness, then the tie is broken +// left-to-right. 
+// +// [smart algorithm]: https://docs.pydantic.dev/latest/concepts/unions/#smart-mode +func (d *decoderBuilder) newUnionDecoder(t reflect.Type) decoderFunc { + unionEntry, ok := unionRegistry[t] + if !ok { + panic("apijson: couldn't find union of type " + t.String() + " in union registry") + } + decoders := []decoderFunc{} + for _, variant := range unionEntry.variants { + decoder := d.typeDecoder(variant.Type) + decoders = append(decoders, decoder) + } + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + // If there is a discriminator match, circumvent the exactness logic entirely + for idx, variant := range unionEntry.variants { + decoder := decoders[idx] + if variant.TypeFilter != n.Type { + continue + } + + if len(unionEntry.discriminatorKey) != 0 { + discriminatorValue := n.Get(unionEntry.discriminatorKey).Value() + if discriminatorValue == variant.DiscriminatorValue { + inner := reflect.New(variant.Type).Elem() + err := decoder(n, inner, state) + v.Set(inner) + return err + } + } + } + + // Set bestExactness to worse than loose + bestExactness := loose - 1 + for idx, variant := range unionEntry.variants { + decoder := decoders[idx] + if variant.TypeFilter != n.Type { + continue + } + sub := decoderState{strict: state.strict, exactness: exact} + inner := reflect.New(variant.Type).Elem() + err := decoder(n, inner, &sub) + if err != nil { + continue + } + if sub.exactness == exact { + v.Set(inner) + return nil + } + if sub.exactness > bestExactness { + v.Set(inner) + bestExactness = sub.exactness + } + } + + if bestExactness < loose { + return errors.New("apijson: was not able to coerce type as union") + } + + if guardStrict(state, bestExactness != exact) { + return errors.New("apijson: was not able to coerce type as union strictly") + } + + return nil + } +} diff --git a/internal/apiquery/encoder.go b/internal/apiquery/encoder.go new file mode 100644 index 0000000..92c7cb8 --- /dev/null +++ b/internal/apiquery/encoder.go @@ -0,0 +1,415 
@@ +package apiquery + +import ( + "encoding/json" + "fmt" + "reflect" + "strconv" + "strings" + "sync" + "time" + + "github.com/ScrapeGraphAI/scrapegraph-sdk/packages/param" +) + +var encoders sync.Map // map[reflect.Type]encoderFunc + +type encoder struct { + dateFormat string + root bool + settings QuerySettings +} + +type encoderFunc func(key string, value reflect.Value) ([]Pair, error) + +type encoderField struct { + tag parsedStructTag + fn encoderFunc + idx []int +} + +type encoderEntry struct { + reflect.Type + dateFormat string + root bool + settings QuerySettings +} + +type Pair struct { + key string + value string +} + +func (e *encoder) typeEncoder(t reflect.Type) encoderFunc { + entry := encoderEntry{ + Type: t, + dateFormat: e.dateFormat, + root: e.root, + settings: e.settings, + } + + if fi, ok := encoders.Load(entry); ok { + return fi.(encoderFunc) + } + + // To deal with recursive types, populate the map with an + // indirect func before we build it. This type waits on the + // real func (f) to be ready and then calls it. This indirect + // func is only used for recursive types. + var ( + wg sync.WaitGroup + f encoderFunc + ) + wg.Add(1) + fi, loaded := encoders.LoadOrStore(entry, encoderFunc(func(key string, v reflect.Value) ([]Pair, error) { + wg.Wait() + return f(key, v) + })) + if loaded { + return fi.(encoderFunc) + } + + // Compute the real encoder and replace the indirect func with it. 
+ f = e.newTypeEncoder(t) + wg.Done() + encoders.Store(entry, f) + return f +} + +func marshalerEncoder(key string, value reflect.Value) ([]Pair, error) { + s, err := value.Interface().(json.Marshaler).MarshalJSON() + if err != nil { + return nil, fmt.Errorf("apiquery: json fallback marshal error %s", err) + } + return []Pair{{key, string(s)}}, nil +} + +func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc { + if t.ConvertibleTo(reflect.TypeOf(time.Time{})) { + return e.newTimeTypeEncoder(t) + } + + if t.Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) { + return e.newRichFieldTypeEncoder(t) + } + + if !e.root && t.Implements(reflect.TypeOf((*json.Marshaler)(nil)).Elem()) { + return marshalerEncoder + } + + e.root = false + switch t.Kind() { + case reflect.Pointer: + encoder := e.typeEncoder(t.Elem()) + return func(key string, value reflect.Value) (pairs []Pair, err error) { + if !value.IsValid() || value.IsNil() { + return + } + return encoder(key, value.Elem()) + } + case reflect.Struct: + return e.newStructTypeEncoder(t) + case reflect.Array: + fallthrough + case reflect.Slice: + return e.newArrayTypeEncoder(t) + case reflect.Map: + return e.newMapEncoder(t) + case reflect.Interface: + return e.newInterfaceEncoder() + default: + return e.newPrimitiveTypeEncoder(t) + } +} + +func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc { + if t.Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) { + return e.newRichFieldTypeEncoder(t) + } + + for i := 0; i < t.NumField(); i++ { + if t.Field(i).Type == paramUnionType && t.Field(i).Anonymous { + return e.newStructUnionTypeEncoder(t) + } + } + + encoderFields := []encoderField{} + + // This helper allows us to recursively collect field encoders into a flat + // array. The parameter `index` keeps track of the access patterns necessary + // to get to some field. 
+ var collectEncoderFields func(r reflect.Type, index []int) + collectEncoderFields = func(r reflect.Type, index []int) { + for i := 0; i < r.NumField(); i++ { + idx := append(index, i) + field := t.FieldByIndex(idx) + if !field.IsExported() { + continue + } + // If this is an embedded struct, traverse one level deeper to extract + // the field and get their encoders as well. + if field.Anonymous { + collectEncoderFields(field.Type, idx) + continue + } + // If query tag is not present, then we skip, which is intentionally + // different behavior from the stdlib. + ptag, ok := parseQueryStructTag(field) + if !ok { + continue + } + + if (ptag.name == "-" || ptag.name == "") && !ptag.inline { + continue + } + + dateFormat, ok := parseFormatStructTag(field) + oldFormat := e.dateFormat + if ok { + switch dateFormat { + case "date-time": + e.dateFormat = time.RFC3339 + case "date": + e.dateFormat = "2006-01-02" + } + } + var encoderFn encoderFunc + if ptag.omitzero { + typeEncoderFn := e.typeEncoder(field.Type) + encoderFn = func(key string, value reflect.Value) ([]Pair, error) { + if value.IsZero() { + return nil, nil + } + return typeEncoderFn(key, value) + } + } else { + encoderFn = e.typeEncoder(field.Type) + } + encoderFields = append(encoderFields, encoderField{ptag, encoderFn, idx}) + e.dateFormat = oldFormat + } + } + collectEncoderFields(t, []int{}) + + return func(key string, value reflect.Value) (pairs []Pair, err error) { + for _, ef := range encoderFields { + var subkey string = e.renderKeyPath(key, ef.tag.name) + if ef.tag.inline { + subkey = key + } + + field := value.FieldByIndex(ef.idx) + subpairs, suberr := ef.fn(subkey, field) + if suberr != nil { + err = suberr + } + pairs = append(pairs, subpairs...) 
+ } + return + } +} + +var paramUnionType = reflect.TypeOf((*param.APIUnion)(nil)).Elem() + +func (e *encoder) newStructUnionTypeEncoder(t reflect.Type) encoderFunc { + var fieldEncoders []encoderFunc + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if field.Type == paramUnionType && field.Anonymous { + fieldEncoders = append(fieldEncoders, nil) + continue + } + fieldEncoders = append(fieldEncoders, e.typeEncoder(field.Type)) + } + + return func(key string, value reflect.Value) (pairs []Pair, err error) { + for i := 0; i < t.NumField(); i++ { + if value.Field(i).Type() == paramUnionType { + continue + } + if !value.Field(i).IsZero() { + return fieldEncoders[i](key, value.Field(i)) + } + } + return nil, fmt.Errorf("apiquery: union %s has no field set", t.String()) + } +} + +func (e *encoder) newMapEncoder(t reflect.Type) encoderFunc { + keyEncoder := e.typeEncoder(t.Key()) + elementEncoder := e.typeEncoder(t.Elem()) + return func(key string, value reflect.Value) (pairs []Pair, err error) { + iter := value.MapRange() + for iter.Next() { + encodedKey, err := keyEncoder("", iter.Key()) + if err != nil { + return nil, err + } + if len(encodedKey) != 1 { + return nil, fmt.Errorf("apiquery: unexpected number of parts for encoded map key, map may contain non-primitive") + } + subkey := encodedKey[0].value + keyPath := e.renderKeyPath(key, subkey) + subpairs, suberr := elementEncoder(keyPath, iter.Value()) + if suberr != nil { + err = suberr + } + pairs = append(pairs, subpairs...) 
+ } + return + } +} + +func (e *encoder) renderKeyPath(key string, subkey string) string { + if len(key) == 0 { + return subkey + } + if e.settings.NestedFormat == NestedQueryFormatDots { + return fmt.Sprintf("%s.%s", key, subkey) + } + return fmt.Sprintf("%s[%s]", key, subkey) +} + +func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc { + switch e.settings.ArrayFormat { + case ArrayQueryFormatComma: + innerEncoder := e.typeEncoder(t.Elem()) + return func(key string, v reflect.Value) ([]Pair, error) { + elements := []string{} + for i := 0; i < v.Len(); i++ { + innerPairs, err := innerEncoder("", v.Index(i)) + if err != nil { + return nil, err + } + for _, pair := range innerPairs { + elements = append(elements, pair.value) + } + } + if len(elements) == 0 { + return []Pair{}, nil + } + return []Pair{{key, strings.Join(elements, ",")}}, nil + } + case ArrayQueryFormatRepeat: + innerEncoder := e.typeEncoder(t.Elem()) + return func(key string, value reflect.Value) (pairs []Pair, err error) { + for i := 0; i < value.Len(); i++ { + subpairs, suberr := innerEncoder(key, value.Index(i)) + if suberr != nil { + err = suberr + } + pairs = append(pairs, subpairs...) + } + return + } + case ArrayQueryFormatIndices: + panic("The array indices format is not supported yet") + case ArrayQueryFormatBrackets: + innerEncoder := e.typeEncoder(t.Elem()) + return func(key string, value reflect.Value) (pairs []Pair, err error) { + pairs = []Pair{} + for i := 0; i < value.Len(); i++ { + subpairs, suberr := innerEncoder(key+"[]", value.Index(i)) + if suberr != nil { + err = suberr + } + pairs = append(pairs, subpairs...) 
+ } + return + } + default: + panic(fmt.Sprintf("Unknown ArrayFormat value: %d", e.settings.ArrayFormat)) + } +} + +func (e *encoder) newPrimitiveTypeEncoder(t reflect.Type) encoderFunc { + switch t.Kind() { + case reflect.Pointer: + inner := t.Elem() + + innerEncoder := e.newPrimitiveTypeEncoder(inner) + return func(key string, v reflect.Value) ([]Pair, error) { + if !v.IsValid() || v.IsNil() { + return nil, nil + } + return innerEncoder(key, v.Elem()) + } + case reflect.String: + return func(key string, v reflect.Value) ([]Pair, error) { + return []Pair{{key, v.String()}}, nil + } + case reflect.Bool: + return func(key string, v reflect.Value) ([]Pair, error) { + if v.Bool() { + return []Pair{{key, "true"}}, nil + } + return []Pair{{key, "false"}}, nil + } + case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: + return func(key string, v reflect.Value) ([]Pair, error) { + return []Pair{{key, strconv.FormatInt(v.Int(), 10)}}, nil + } + case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return func(key string, v reflect.Value) ([]Pair, error) { + return []Pair{{key, strconv.FormatUint(v.Uint(), 10)}}, nil + } + case reflect.Float32, reflect.Float64: + return func(key string, v reflect.Value) ([]Pair, error) { + return []Pair{{key, strconv.FormatFloat(v.Float(), 'f', -1, 64)}}, nil + } + case reflect.Complex64, reflect.Complex128: + bitSize := 64 + if t.Kind() == reflect.Complex128 { + bitSize = 128 + } + return func(key string, v reflect.Value) ([]Pair, error) { + return []Pair{{key, strconv.FormatComplex(v.Complex(), 'f', -1, bitSize)}}, nil + } + default: + return func(key string, v reflect.Value) ([]Pair, error) { + return nil, nil + } + } +} + +func (e *encoder) newFieldTypeEncoder(t reflect.Type) encoderFunc { + f, _ := t.FieldByName("Value") + enc := e.typeEncoder(f.Type) + + return func(key string, value reflect.Value) ([]Pair, error) { + present := value.FieldByName("Present") + if !present.Bool() { + return nil, nil + } + null 
:= value.FieldByName("Null") + if null.Bool() { + return nil, fmt.Errorf("apiquery: field cannot be null") + } + raw := value.FieldByName("Raw") + if !raw.IsNil() { + return e.typeEncoder(raw.Type())(key, raw) + } + return enc(key, value.FieldByName("Value")) + } +} + +func (e *encoder) newTimeTypeEncoder(_ reflect.Type) encoderFunc { + format := e.dateFormat + return func(key string, value reflect.Value) ([]Pair, error) { + return []Pair{{ + key, + value.Convert(reflect.TypeOf(time.Time{})).Interface().(time.Time).Format(format), + }}, nil + } +} + +func (e encoder) newInterfaceEncoder() encoderFunc { + return func(key string, value reflect.Value) ([]Pair, error) { + value = value.Elem() + if !value.IsValid() { + return nil, nil + } + return e.typeEncoder(value.Type())(key, value) + } + +} diff --git a/internal/apiquery/query.go b/internal/apiquery/query.go new file mode 100644 index 0000000..0f379fa --- /dev/null +++ b/internal/apiquery/query.go @@ -0,0 +1,55 @@ +package apiquery + +import ( + "net/url" + "reflect" + "time" +) + +func MarshalWithSettings(value any, settings QuerySettings) (url.Values, error) { + e := encoder{time.RFC3339, true, settings} + kv := url.Values{} + val := reflect.ValueOf(value) + if !val.IsValid() { + return nil, nil + } + typ := val.Type() + + pairs, err := e.typeEncoder(typ)("", val) + if err != nil { + return nil, err + } + for _, pair := range pairs { + kv.Add(pair.key, pair.value) + } + return kv, nil +} + +func Marshal(value any) (url.Values, error) { + return MarshalWithSettings(value, QuerySettings{}) +} + +type Queryer interface { + URLQuery() (url.Values, error) +} + +type QuerySettings struct { + NestedFormat NestedQueryFormat + ArrayFormat ArrayQueryFormat +} + +type NestedQueryFormat int + +const ( + NestedQueryFormatBrackets NestedQueryFormat = iota + NestedQueryFormatDots +) + +type ArrayQueryFormat int + +const ( + ArrayQueryFormatComma ArrayQueryFormat = iota + ArrayQueryFormatRepeat + ArrayQueryFormatIndices + 
ArrayQueryFormatBrackets +) diff --git a/internal/apiquery/query_test.go b/internal/apiquery/query_test.go new file mode 100644 index 0000000..cdfc30a --- /dev/null +++ b/internal/apiquery/query_test.go @@ -0,0 +1,435 @@ +package apiquery + +import ( + "github.com/ScrapeGraphAI/scrapegraph-sdk/packages/param" + "net/url" + "testing" + "time" +) + +func P[T any](v T) *T { return &v } + +type Primitives struct { + A bool `query:"a"` + B int `query:"b"` + C uint `query:"c"` + D float64 `query:"d"` + E float32 `query:"e"` + F []int `query:"f"` +} + +type PrimitivePointers struct { + A *bool `query:"a"` + B *int `query:"b"` + C *uint `query:"c"` + D *float64 `query:"d"` + E *float32 `query:"e"` + F *[]int `query:"f"` +} + +type Slices struct { + Slice []Primitives `query:"slices"` + Mixed []any `query:"mixed"` +} + +type DateTime struct { + Date time.Time `query:"date" format:"date"` + DateTime time.Time `query:"date-time" format:"date-time"` +} + +type AdditionalProperties struct { + A bool `query:"a"` + Extras map[string]any `query:"-,inline"` +} + +type Recursive struct { + Name string `query:"name"` + Child *Recursive `query:"child"` +} + +type UnknownStruct struct { + Unknown any `query:"unknown"` +} + +type UnionStruct struct { + Union Union `query:"union" format:"date"` +} + +type Union interface { + union() +} + +type UnionInteger int64 + +func (UnionInteger) union() {} + +type UnionString string + +func (UnionString) union() {} + +type UnionStructA struct { + Type string `query:"type"` + A string `query:"a"` + B string `query:"b"` +} + +func (UnionStructA) union() {} + +type UnionStructB struct { + Type string `query:"type"` + A string `query:"a"` +} + +func (UnionStructB) union() {} + +type UnionTime time.Time + +func (UnionTime) union() {} + +type DeeplyNested struct { + A DeeplyNested1 `query:"a"` +} + +type DeeplyNested1 struct { + B DeeplyNested2 `query:"b"` +} + +type DeeplyNested2 struct { + C DeeplyNested3 `query:"c"` +} + +type DeeplyNested3 struct { + 
D *string `query:"d"` +} + +type RichPrimitives struct { + A param.Opt[string] `query:"a"` +} + +type QueryOmitTest struct { + A param.Opt[string] `query:"a,omitzero"` + B string `query:"b,omitzero"` +} + +type NamedEnum string + +const NamedEnumFoo NamedEnum = "foo" + +type StructUnionWrapper struct { + Union StructUnion `query:"union"` +} + +type StructUnion struct { + OfInt param.Opt[int64] `query:",omitzero,inline"` + OfString param.Opt[string] `query:",omitzero,inline"` + OfEnum param.Opt[NamedEnum] `query:",omitzero,inline"` + OfA UnionStructA `query:",omitzero,inline"` + OfB UnionStructB `query:",omitzero,inline"` + param.APIUnion +} + +var tests = map[string]struct { + enc string + val any + settings QuerySettings +}{ + "primitives": { + "a=false&b=237628372683&c=654&d=9999.43&e=43.7599983215332&f=1,2,3,4", + Primitives{A: false, B: 237628372683, C: uint(654), D: 9999.43, E: 43.76, F: []int{1, 2, 3, 4}}, + QuerySettings{}, + }, + + "slices_brackets": { + `mixed[]=1&mixed[]=2.3&mixed[]=hello&slices[][a]=false&slices[][a]=false&slices[][b]=237628372683&slices[][b]=237628372683&slices[][c]=654&slices[][c]=654&slices[][d]=9999.43&slices[][d]=9999.43&slices[][e]=43.7599983215332&slices[][e]=43.7599983215332&slices[][f][]=1&slices[][f][]=2&slices[][f][]=3&slices[][f][]=4&slices[][f][]=1&slices[][f][]=2&slices[][f][]=3&slices[][f][]=4`, + Slices{ + Slice: []Primitives{ + {A: false, B: 237628372683, C: uint(654), D: 9999.43, E: 43.76, F: []int{1, 2, 3, 4}}, + {A: false, B: 237628372683, C: uint(654), D: 9999.43, E: 43.76, F: []int{1, 2, 3, 4}}, + }, + Mixed: []any{1, 2.3, "hello"}, + }, + QuerySettings{ArrayFormat: ArrayQueryFormatBrackets}, + }, + + "slices_comma": { + `mixed=1,2.3,hello`, + Slices{ + Mixed: []any{1, 2.3, "hello"}, + }, + QuerySettings{ArrayFormat: ArrayQueryFormatComma}, + }, + + "slices_repeat": { + 
`mixed=1&mixed=2.3&mixed=hello&slices[a]=false&slices[a]=false&slices[b]=237628372683&slices[b]=237628372683&slices[c]=654&slices[c]=654&slices[d]=9999.43&slices[d]=9999.43&slices[e]=43.7599983215332&slices[e]=43.7599983215332&slices[f]=1&slices[f]=2&slices[f]=3&slices[f]=4&slices[f]=1&slices[f]=2&slices[f]=3&slices[f]=4`, + Slices{ + Slice: []Primitives{ + {A: false, B: 237628372683, C: uint(654), D: 9999.43, E: 43.76, F: []int{1, 2, 3, 4}}, + {A: false, B: 237628372683, C: uint(654), D: 9999.43, E: 43.76, F: []int{1, 2, 3, 4}}, + }, + Mixed: []any{1, 2.3, "hello"}, + }, + QuerySettings{ArrayFormat: ArrayQueryFormatRepeat}, + }, + + "primitive_pointer_struct": { + "a=false&b=237628372683&c=654&d=9999.43&e=43.7599983215332&f=1,2,3,4,5", + PrimitivePointers{ + A: P(false), + B: P(237628372683), + C: P(uint(654)), + D: P(9999.43), + E: P(float32(43.76)), + F: &[]int{1, 2, 3, 4, 5}, + }, + QuerySettings{}, + }, + + "datetime_struct": { + `date=2006-01-02&date-time=2006-01-02T15:04:05Z`, + DateTime{ + Date: time.Date(2006, time.January, 2, 0, 0, 0, 0, time.UTC), + DateTime: time.Date(2006, time.January, 2, 15, 4, 5, 0, time.UTC), + }, + QuerySettings{}, + }, + + "additional_properties": { + `a=true&bar=value&foo=true`, + AdditionalProperties{ + A: true, + Extras: map[string]any{ + "bar": "value", + "foo": true, + }, + }, + QuerySettings{}, + }, + + "recursive_struct_brackets": { + `child[name]=Alex&name=Robert`, + Recursive{Name: "Robert", Child: &Recursive{Name: "Alex"}}, + QuerySettings{NestedFormat: NestedQueryFormatBrackets}, + }, + + "recursive_struct_dots": { + `child.name=Alex&name=Robert`, + Recursive{Name: "Robert", Child: &Recursive{Name: "Alex"}}, + QuerySettings{NestedFormat: NestedQueryFormatDots}, + }, + + "unknown_struct_number": { + `unknown=12`, + UnknownStruct{ + Unknown: 12., + }, + QuerySettings{}, + }, + + "unknown_struct_map_brackets": { + `unknown[foo]=bar`, + UnknownStruct{ + Unknown: map[string]any{ + "foo": "bar", + }, + }, + 
QuerySettings{NestedFormat: NestedQueryFormatBrackets}, + }, + + "unknown_struct_map_dots": { + `unknown.foo=bar`, + UnknownStruct{ + Unknown: map[string]any{ + "foo": "bar", + }, + }, + QuerySettings{NestedFormat: NestedQueryFormatDots}, + }, + + "struct_union_string": { + `union=hello`, + StructUnionWrapper{ + Union: StructUnion{OfString: param.NewOpt("hello")}, + }, + QuerySettings{}, + }, + + "union_string": { + `union=hello`, + UnionStruct{ + Union: UnionString("hello"), + }, + QuerySettings{}, + }, + + "struct_union_integer": { + `union=12`, + StructUnionWrapper{ + Union: StructUnion{OfInt: param.NewOpt[int64](12)}, + }, + QuerySettings{}, + }, + + "union_integer": { + `union=12`, + UnionStruct{ + Union: UnionInteger(12), + }, + QuerySettings{}, + }, + + "struct_union_enum": { + `union=foo`, + StructUnionWrapper{ + Union: StructUnion{OfEnum: param.NewOpt[NamedEnum](NamedEnumFoo)}, + }, + QuerySettings{}, + }, + + "struct_union_struct_discriminated_a": { + `union[a]=foo&union[b]=bar&union[type]=typeA`, + StructUnionWrapper{ + Union: StructUnion{OfA: UnionStructA{ + Type: "typeA", + A: "foo", + B: "bar", + }}, + }, + QuerySettings{}, + }, + + "union_struct_discriminated_a": { + `union[a]=foo&union[b]=bar&union[type]=typeA`, + UnionStruct{ + Union: UnionStructA{ + Type: "typeA", + A: "foo", + B: "bar", + }, + }, + QuerySettings{}, + }, + + "struct_union_struct_discriminated_b": { + `union[a]=foo&union[type]=typeB`, + StructUnionWrapper{ + Union: StructUnion{OfB: UnionStructB{ + Type: "typeB", + A: "foo", + }}, + }, + QuerySettings{}, + }, + + "union_struct_discriminated_b": { + `union[a]=foo&union[type]=typeB`, + UnionStruct{ + Union: UnionStructB{ + Type: "typeB", + A: "foo", + }, + }, + QuerySettings{}, + }, + + "union_struct_time": { + `union=2010-05-23`, + UnionStruct{ + Union: UnionTime(time.Date(2010, 05, 23, 0, 0, 0, 0, time.UTC)), + }, + QuerySettings{}, + }, + + "deeply_nested_brackets": { + `a[b][c][d]=hello`, + DeeplyNested{ + A: DeeplyNested1{ + B: 
DeeplyNested2{ + C: DeeplyNested3{ + D: P("hello"), + }, + }, + }, + }, + QuerySettings{NestedFormat: NestedQueryFormatBrackets}, + }, + + "deeply_nested_dots": { + `a.b.c.d=hello`, + DeeplyNested{ + A: DeeplyNested1{ + B: DeeplyNested2{ + C: DeeplyNested3{ + D: P("hello"), + }, + }, + }, + }, + QuerySettings{NestedFormat: NestedQueryFormatDots}, + }, + + "deeply_nested_brackets_empty": { + ``, + DeeplyNested{ + A: DeeplyNested1{ + B: DeeplyNested2{ + C: DeeplyNested3{ + D: nil, + }, + }, + }, + }, + QuerySettings{NestedFormat: NestedQueryFormatBrackets}, + }, + + "deeply_nested_dots_empty": { + ``, + DeeplyNested{ + A: DeeplyNested1{ + B: DeeplyNested2{ + C: DeeplyNested3{ + D: nil, + }, + }, + }, + }, + QuerySettings{NestedFormat: NestedQueryFormatDots}, + }, + + "rich_primitives": { + `a=hello`, + RichPrimitives{ + A: param.Opt[string]{Value: "hello"}, + }, + QuerySettings{}, + }, + + "rich_primitives_omit": { + ``, + QueryOmitTest{ + A: param.Opt[string]{}, + }, + QuerySettings{}, + }, + "query_omit": { + `a=hello`, + QueryOmitTest{ + A: param.Opt[string]{Value: "hello"}, + }, + QuerySettings{}, + }, +} + +func TestEncode(t *testing.T) { + for name, test := range tests { + t.Run(name, func(t *testing.T) { + values, err := MarshalWithSettings(test.val, test.settings) + if err != nil { + t.Fatalf("failed to marshal url %s", err) + } + str, _ := url.QueryUnescape(values.Encode()) + if str != test.enc { + t.Fatalf("expected %+#v to serialize to %s but got %s", test.val, test.enc, str) + } + }) + } +} diff --git a/internal/apiquery/richparam.go b/internal/apiquery/richparam.go new file mode 100644 index 0000000..0316506 --- /dev/null +++ b/internal/apiquery/richparam.go @@ -0,0 +1,19 @@ +package apiquery + +import ( + "github.com/ScrapeGraphAI/scrapegraph-sdk/packages/param" + "reflect" +) + +func (e *encoder) newRichFieldTypeEncoder(t reflect.Type) encoderFunc { + f, _ := t.FieldByName("Value") + enc := e.typeEncoder(f.Type) + return func(key string, value 
reflect.Value) ([]Pair, error) { + if opt, ok := value.Interface().(param.Optional); ok && opt.Valid() { + return enc(key, value.FieldByIndex(f.Index)) + } else if ok && param.IsNull(opt) { + return []Pair{{key, "null"}}, nil + } + return nil, nil + } +} diff --git a/internal/apiquery/tag.go b/internal/apiquery/tag.go new file mode 100644 index 0000000..772c40e --- /dev/null +++ b/internal/apiquery/tag.go @@ -0,0 +1,44 @@ +package apiquery + +import ( + "reflect" + "strings" +) + +const queryStructTag = "query" +const formatStructTag = "format" + +// parsedStructTag holds the decoded parts of a "query" struct tag: the wire +// name plus the omitempty/omitzero/inline options. +type parsedStructTag struct { + name string + omitempty bool + omitzero bool + inline bool +} + +// parseQueryStructTag parses field's "query" struct tag ("name,opt,opt"). +// ok reports whether the tag was present; unrecognized options are ignored. +func parseQueryStructTag(field reflect.StructField) (tag parsedStructTag, ok bool) { + raw, ok := field.Tag.Lookup(queryStructTag) + if !ok { + return + } + parts := strings.Split(raw, ",") + if len(parts) == 0 { + return tag, false + } + tag.name = parts[0] + for _, part := range parts[1:] { + switch part { + case "omitzero": + tag.omitzero = true + case "omitempty": + tag.omitempty = true + case "inline": + tag.inline = true + } + } + return +} + +// parseFormatStructTag returns field's "format" struct tag verbatim; ok +// reports whether the tag was present. +func parseFormatStructTag(field reflect.StructField) (format string, ok bool) { + format, ok = field.Tag.Lookup(formatStructTag) + return +} diff --git a/internal/encoding/json/decode.go b/internal/encoding/json/decode.go new file mode 100644 index 0000000..4f1705c --- /dev/null +++ b/internal/encoding/json/decode.go @@ -0,0 +1,1324 @@ +// Vendored from Go 1.24.0-pre-release +// To find alterations, check package shims, and comments beginning in SHIM(). +// +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Represents JSON data structure using native Go types: booleans, floats, +// strings, arrays, and maps.
+ +package json + +import ( + "encoding" + "encoding/base64" + "fmt" + "github.com/ScrapeGraphAI/scrapegraph-sdk/internal/encoding/json/shims" + "reflect" + "strconv" + "strings" + "unicode" + "unicode/utf16" + "unicode/utf8" + _ "unsafe" // for linkname +) + +// Unmarshal parses the JSON-encoded data and stores the result +// in the value pointed to by v. If v is nil or not a pointer, +// Unmarshal returns an [InvalidUnmarshalError]. +// +// Unmarshal uses the inverse of the encodings that +// [Marshal] uses, allocating maps, slices, and pointers as necessary, +// with the following additional rules: +// +// To unmarshal JSON into a pointer, Unmarshal first handles the case of +// the JSON being the JSON literal null. In that case, Unmarshal sets +// the pointer to nil. Otherwise, Unmarshal unmarshals the JSON into +// the value pointed at by the pointer. If the pointer is nil, Unmarshal +// allocates a new value for it to point to. +// +// To unmarshal JSON into a value implementing [Unmarshaler], +// Unmarshal calls that value's [Unmarshaler.UnmarshalJSON] method, including +// when the input is a JSON null. +// Otherwise, if the value implements [encoding.TextUnmarshaler] +// and the input is a JSON quoted string, Unmarshal calls +// [encoding.TextUnmarshaler.UnmarshalText] with the unquoted form of the string. +// +// To unmarshal JSON into a struct, Unmarshal matches incoming object +// keys to the keys used by [Marshal] (either the struct field name or its tag), +// preferring an exact match but also accepting a case-insensitive match. By +// default, object keys which don't have a corresponding struct field are +// ignored (see [Decoder.DisallowUnknownFields] for an alternative). 
+// +// To unmarshal JSON into an interface value, +// Unmarshal stores one of these in the interface value: +// +// - bool, for JSON booleans +// - float64, for JSON numbers +// - string, for JSON strings +// - []any, for JSON arrays +// - map[string]any, for JSON objects +// - nil for JSON null +// +// To unmarshal a JSON array into a slice, Unmarshal resets the slice length +// to zero and then appends each element to the slice. +// As a special case, to unmarshal an empty JSON array into a slice, +// Unmarshal replaces the slice with a new empty slice. +// +// To unmarshal a JSON array into a Go array, Unmarshal decodes +// JSON array elements into corresponding Go array elements. +// If the Go array is smaller than the JSON array, +// the additional JSON array elements are discarded. +// If the JSON array is smaller than the Go array, +// the additional Go array elements are set to zero values. +// +// To unmarshal a JSON object into a map, Unmarshal first establishes a map to +// use. If the map is nil, Unmarshal allocates a new map. Otherwise Unmarshal +// reuses the existing map, keeping existing entries. Unmarshal then stores +// key-value pairs from the JSON object into the map. The map's key type must +// either be any string type, an integer, or implement [encoding.TextUnmarshaler]. +// +// If the JSON-encoded data contain a syntax error, Unmarshal returns a [SyntaxError]. +// +// If a JSON value is not appropriate for a given target type, +// or if a JSON number overflows the target type, Unmarshal +// skips that field and completes the unmarshaling as best it can. +// If no more serious errors are encountered, Unmarshal returns +// an [UnmarshalTypeError] describing the earliest such error. In any +// case, it's not guaranteed that all the remaining fields following +// the problematic one will be unmarshaled into the target object. +// +// The JSON null value unmarshals into an interface, map, pointer, or slice +// by setting that Go value to nil. 
Because null is often used in JSON to mean +// “not present,” unmarshaling a JSON null into any other Go type has no effect +// on the value and produces no error. +// +// When unmarshaling quoted strings, invalid UTF-8 or +// invalid UTF-16 surrogate pairs are not treated as an error. +// Instead, they are replaced by the Unicode replacement +// character U+FFFD. +func Unmarshal(data []byte, v any) error { + // Check for well-formedness. + // Avoids filling out half a data structure + // before discovering a JSON syntax error. + var d decodeState + err := checkValid(data, &d.scan) + if err != nil { + return err + } + + d.init(data) + return d.unmarshal(v) +} + +// Unmarshaler is the interface implemented by types +// that can unmarshal a JSON description of themselves. +// The input can be assumed to be a valid encoding of +// a JSON value. UnmarshalJSON must copy the JSON data +// if it wishes to retain the data after returning. +// +// By convention, to approximate the behavior of [Unmarshal] itself, +// Unmarshalers implement UnmarshalJSON([]byte("null")) as a no-op. +type Unmarshaler interface { + UnmarshalJSON([]byte) error +} + +// An UnmarshalTypeError describes a JSON value that was +// not appropriate for a value of a specific Go type. +type UnmarshalTypeError struct { + Value string // description of JSON value - "bool", "array", "number -5" + Type reflect.Type // type of Go value it could not be assigned to + Offset int64 // error occurred after reading Offset bytes + Struct string // name of the struct type containing the field + Field string // the full path from root node to the field, include embedded struct +} + +func (e *UnmarshalTypeError) Error() string { + if e.Struct != "" || e.Field != "" { + return "json: cannot unmarshal " + e.Value + " into Go struct field " + e.Struct + "." 
+ e.Field + " of type " + e.Type.String() + } + return "json: cannot unmarshal " + e.Value + " into Go value of type " + e.Type.String() +} + +// An UnmarshalFieldError describes a JSON object key that +// led to an unexported (and therefore unwritable) struct field. +// +// Deprecated: No longer used; kept for compatibility. +type UnmarshalFieldError struct { + Key string + Type reflect.Type + Field reflect.StructField +} + +func (e *UnmarshalFieldError) Error() string { + return "json: cannot unmarshal object key " + strconv.Quote(e.Key) + " into unexported field " + e.Field.Name + " of type " + e.Type.String() +} + +// An InvalidUnmarshalError describes an invalid argument passed to [Unmarshal]. +// (The argument to [Unmarshal] must be a non-nil pointer.) +type InvalidUnmarshalError struct { + Type reflect.Type +} + +func (e *InvalidUnmarshalError) Error() string { + if e.Type == nil { + return "json: Unmarshal(nil)" + } + + if e.Type.Kind() != reflect.Pointer { + return "json: Unmarshal(non-pointer " + e.Type.String() + ")" + } + return "json: Unmarshal(nil " + e.Type.String() + ")" +} + +func (d *decodeState) unmarshal(v any) error { + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Pointer || rv.IsNil() { + return &InvalidUnmarshalError{reflect.TypeOf(v)} + } + + d.scan.reset() + d.scanWhile(scanSkipSpace) + // We decode rv not rv.Elem because the Unmarshaler interface + // test must be applied at the top level of the value. + err := d.value(rv) + if err != nil { + return d.addErrorContext(err) + } + return d.savedError +} + +// A Number represents a JSON number literal. +type Number string + +// String returns the literal text of the number. +func (n Number) String() string { return string(n) } + +// Float64 returns the number as a float64. +func (n Number) Float64() (float64, error) { + return strconv.ParseFloat(string(n), 64) +} + +// Int64 returns the number as an int64. 
+func (n Number) Int64() (int64, error) { + return strconv.ParseInt(string(n), 10, 64) +} + +// An errorContext provides context for type errors during decoding. +type errorContext struct { + Struct reflect.Type + FieldStack []string +} + +// decodeState represents the state while decoding a JSON value. +type decodeState struct { + data []byte + off int // next read offset in data + opcode int // last read result + scan scanner + errorContext *errorContext + savedError error + useNumber bool + disallowUnknownFields bool +} + +// readIndex returns the position of the last byte read. +func (d *decodeState) readIndex() int { + return d.off - 1 +} + +// phasePanicMsg is used as a panic message when we end up with something that +// shouldn't happen. It can indicate a bug in the JSON decoder, or that +// something is editing the data slice while the decoder executes. +const phasePanicMsg = "JSON decoder out of sync - data changing underfoot?" + +func (d *decodeState) init(data []byte) *decodeState { + d.data = data + d.off = 0 + d.savedError = nil + if d.errorContext != nil { + d.errorContext.Struct = nil + // Reuse the allocated space for the FieldStack slice. + d.errorContext.FieldStack = d.errorContext.FieldStack[:0] + } + return d +} + +// saveError saves the first err it is called with, +// for reporting at the end of the unmarshal. 
+func (d *decodeState) saveError(err error) { + if d.savedError == nil { + d.savedError = d.addErrorContext(err) + } +} + +// addErrorContext returns a new error enhanced with information from d.errorContext +func (d *decodeState) addErrorContext(err error) error { + if d.errorContext != nil && (d.errorContext.Struct != nil || len(d.errorContext.FieldStack) > 0) { + switch err := err.(type) { + case *UnmarshalTypeError: + err.Struct = d.errorContext.Struct.Name() + fieldStack := d.errorContext.FieldStack + if err.Field != "" { + fieldStack = append(fieldStack, err.Field) + } + err.Field = strings.Join(fieldStack, ".") + } + } + return err +} + +// skip scans to the end of what was started. +func (d *decodeState) skip() { + s, data, i := &d.scan, d.data, d.off + depth := len(s.parseState) + for { + op := s.step(s, data[i]) + i++ + if len(s.parseState) < depth { + d.off = i + d.opcode = op + return + } + } +} + +// scanNext processes the byte at d.data[d.off]. +func (d *decodeState) scanNext() { + if d.off < len(d.data) { + d.opcode = d.scan.step(&d.scan, d.data[d.off]) + d.off++ + } else { + d.opcode = d.scan.eof() + d.off = len(d.data) + 1 // mark processed EOF with len+1 + } +} + +// scanWhile processes bytes in d.data[d.off:] until it +// receives a scan code not equal to op. +func (d *decodeState) scanWhile(op int) { + s, data, i := &d.scan, d.data, d.off + for i < len(data) { + newOp := s.step(s, data[i]) + i++ + if newOp != op { + d.opcode = newOp + d.off = i + return + } + } + + d.off = len(data) + 1 // mark processed EOF with len+1 + d.opcode = d.scan.eof() +} + +// rescanLiteral is similar to scanWhile(scanContinue), but it specialises the +// common case where we're decoding a literal. The decoder scans the input +// twice, once for syntax errors and to check the length of the value, and the +// second to perform the decoding. +// +// Only in the second step do we use decodeState to tokenize literals, so we +// know there aren't any syntax errors. 
We can take advantage of that knowledge, +// and scan a literal's bytes much more quickly. +func (d *decodeState) rescanLiteral() { + data, i := d.data, d.off +Switch: + switch data[i-1] { + case '"': // string + for ; i < len(data); i++ { + switch data[i] { + case '\\': + i++ // escaped char + case '"': + i++ // tokenize the closing quote too + break Switch + } + } + case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-': // number + for ; i < len(data); i++ { + switch data[i] { + case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', + '.', 'e', 'E', '+', '-': + default: + break Switch + } + } + case 't': // true + i += len("rue") + case 'f': // false + i += len("alse") + case 'n': // null + i += len("ull") + } + if i < len(data) { + d.opcode = stateEndValue(&d.scan, data[i]) + } else { + d.opcode = scanEnd + } + d.off = i + 1 +} + +// value consumes a JSON value from d.data[d.off-1:], decoding into v, and +// reads the following byte ahead. If v is invalid, the value is discarded. +// The first byte of the value has been read already. +func (d *decodeState) value(v reflect.Value) error { + switch d.opcode { + default: + panic(phasePanicMsg) + + case scanBeginArray: + if v.IsValid() { + if err := d.array(v); err != nil { + return err + } + } else { + d.skip() + } + d.scanNext() + + case scanBeginObject: + if v.IsValid() { + if err := d.object(v); err != nil { + return err + } + } else { + d.skip() + } + d.scanNext() + + case scanBeginLiteral: + // All bytes inside literal return scanContinue op code. + start := d.readIndex() + d.rescanLiteral() + + if v.IsValid() { + if err := d.literalStore(d.data[start:d.readIndex()], v, false); err != nil { + return err + } + } + } + return nil +} + +type unquotedValue struct{} + +// valueQuoted is like value but decodes a +// quoted string literal or literal null into an interface value. +// If it finds anything other than a quoted string literal or null, +// valueQuoted returns unquotedValue{}. 
+func (d *decodeState) valueQuoted() any { + switch d.opcode { + default: + panic(phasePanicMsg) + + case scanBeginArray, scanBeginObject: + d.skip() + d.scanNext() + + case scanBeginLiteral: + v := d.literalInterface() + switch v.(type) { + case nil, string: + return v + } + } + return unquotedValue{} +} + +// indirect walks down v allocating pointers as needed, +// until it gets to a non-pointer. +// If it encounters an Unmarshaler, indirect stops and returns that. +// If decodingNull is true, indirect stops at the first settable pointer so it +// can be set to nil. +func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) { + // Issue #24153 indicates that it is generally not a guaranteed property + // that you may round-trip a reflect.Value by calling Value.Addr().Elem() + // and expect the value to still be settable for values derived from + // unexported embedded struct fields. + // + // The logic below effectively does this when it first addresses the value + // (to satisfy possible pointer methods) and continues to dereference + // subsequent pointers as necessary. + // + // After the first round-trip, we set v back to the original value to + // preserve the original RW flags contained in reflect.Value. + v0 := v + haveAddr := false + + // If v is a named type and is addressable, + // start with its address, so that if the type has pointer methods, + // we find them. + if v.Kind() != reflect.Pointer && v.Type().Name() != "" && v.CanAddr() { + haveAddr = true + v = v.Addr() + } + for { + // Load value from interface, but only if the result will be + // usefully addressable. 
+ if v.Kind() == reflect.Interface && !v.IsNil() { + e := v.Elem() + if e.Kind() == reflect.Pointer && !e.IsNil() && (!decodingNull || e.Elem().Kind() == reflect.Pointer) { + haveAddr = false + v = e + continue + } + } + + if v.Kind() != reflect.Pointer { + break + } + + if decodingNull && v.CanSet() { + break + } + + // Prevent infinite loop if v is an interface pointing to its own address: + // var v any + // v = &v + if v.Elem().Kind() == reflect.Interface && v.Elem().Elem().Equal(v) { + v = v.Elem() + break + } + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + if v.Type().NumMethod() > 0 && v.CanInterface() { + if u, ok := v.Interface().(Unmarshaler); ok { + return u, nil, reflect.Value{} + } + if !decodingNull { + if u, ok := v.Interface().(encoding.TextUnmarshaler); ok { + return nil, u, reflect.Value{} + } + } + } + + if haveAddr { + v = v0 // restore original value after round-trip Value.Addr().Elem() + haveAddr = false + } else { + v = v.Elem() + } + } + return nil, nil, v +} + +// array consumes an array from d.data[d.off-1:], decoding into v. +// The first byte of the array ('[') has been read already. +func (d *decodeState) array(v reflect.Value) error { + // Check for unmarshaler. + u, ut, pv := indirect(v, false) + if u != nil { + start := d.readIndex() + d.skip() + return u.UnmarshalJSON(d.data[start:d.off]) + } + if ut != nil { + d.saveError(&UnmarshalTypeError{Value: "array", Type: v.Type(), Offset: int64(d.off)}) + d.skip() + return nil + } + v = pv + + // Check type of target. + switch v.Kind() { + case reflect.Interface: + if v.NumMethod() == 0 { + // Decoding into nil interface? Switch to non-reflect code. + ai := d.arrayInterface() + v.Set(reflect.ValueOf(ai)) + return nil + } + // Otherwise it's invalid. 
+ fallthrough + default: + d.saveError(&UnmarshalTypeError{Value: "array", Type: v.Type(), Offset: int64(d.off)}) + d.skip() + return nil + case reflect.Array, reflect.Slice: + break + } + + i := 0 + for { + // Look ahead for ] - can only happen on first iteration. + d.scanWhile(scanSkipSpace) + if d.opcode == scanEndArray { + break + } + + // Expand slice length, growing the slice if necessary. + if v.Kind() == reflect.Slice { + if i >= v.Cap() { + v.Grow(1) + } + if i >= v.Len() { + v.SetLen(i + 1) + } + } + + if i < v.Len() { + // Decode into element. + if err := d.value(v.Index(i)); err != nil { + return err + } + } else { + // Ran out of fixed array: skip. + if err := d.value(reflect.Value{}); err != nil { + return err + } + } + i++ + + // Next token must be , or ]. + if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + if d.opcode == scanEndArray { + break + } + if d.opcode != scanArrayValue { + panic(phasePanicMsg) + } + } + + if i < v.Len() { + if v.Kind() == reflect.Array { + for ; i < v.Len(); i++ { + v.Index(i).SetZero() // zero remainder of array + } + } else { + v.SetLen(i) // truncate the slice + } + } + if i == 0 && v.Kind() == reflect.Slice { + v.Set(reflect.MakeSlice(v.Type(), 0, 0)) + } + return nil +} + +var nullLiteral = []byte("null") + +// SHIM(reflect): reflect.TypeFor[T]() reflect.T +var textUnmarshalerType = shims.TypeFor[encoding.TextUnmarshaler]() + +// object consumes an object from d.data[d.off-1:], decoding into v. +// The first byte ('{') of the object has been read already. +func (d *decodeState) object(v reflect.Value) error { + // Check for unmarshaler. + u, ut, pv := indirect(v, false) + if u != nil { + start := d.readIndex() + d.skip() + return u.UnmarshalJSON(d.data[start:d.off]) + } + if ut != nil { + d.saveError(&UnmarshalTypeError{Value: "object", Type: v.Type(), Offset: int64(d.off)}) + d.skip() + return nil + } + v = pv + t := v.Type() + + // Decoding into nil interface? Switch to non-reflect code. 
+ if v.Kind() == reflect.Interface && v.NumMethod() == 0 { + oi := d.objectInterface() + v.Set(reflect.ValueOf(oi)) + return nil + } + + var fields structFields + + // Check type of target: + // struct or + // map[T1]T2 where T1 is string, an integer type, + // or an encoding.TextUnmarshaler + switch v.Kind() { + case reflect.Map: + // Map key must either have string kind, have an integer kind, + // or be an encoding.TextUnmarshaler. + switch t.Key().Kind() { + case reflect.String, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + default: + if !reflect.PointerTo(t.Key()).Implements(textUnmarshalerType) { + d.saveError(&UnmarshalTypeError{Value: "object", Type: t, Offset: int64(d.off)}) + d.skip() + return nil + } + } + if v.IsNil() { + v.Set(reflect.MakeMap(t)) + } + case reflect.Struct: + fields = cachedTypeFields(t) + // ok + default: + d.saveError(&UnmarshalTypeError{Value: "object", Type: t, Offset: int64(d.off)}) + d.skip() + return nil + } + + var mapElem reflect.Value + var origErrorContext errorContext + if d.errorContext != nil { + origErrorContext = *d.errorContext + } + + for { + // Read opening " of string key or closing }. + d.scanWhile(scanSkipSpace) + if d.opcode == scanEndObject { + // closing } - can only happen on first iteration. + break + } + if d.opcode != scanBeginLiteral { + panic(phasePanicMsg) + } + + // Read key. + start := d.readIndex() + d.rescanLiteral() + item := d.data[start:d.readIndex()] + key, ok := unquoteBytes(item) + if !ok { + panic(phasePanicMsg) + } + + // Figure out field corresponding to key. 
+ var subv reflect.Value + destring := false // whether the value is wrapped in a string to be decoded first + + if v.Kind() == reflect.Map { + elemType := t.Elem() + if !mapElem.IsValid() { + mapElem = reflect.New(elemType).Elem() + } else { + mapElem.SetZero() + } + subv = mapElem + } else { + f := fields.byExactName[string(key)] + if f == nil { + f = fields.byFoldedName[string(foldName(key))] + } + if f != nil { + subv = v + destring = f.quoted + if d.errorContext == nil { + d.errorContext = new(errorContext) + } + for i, ind := range f.index { + if subv.Kind() == reflect.Pointer { + if subv.IsNil() { + // If a struct embeds a pointer to an unexported type, + // it is not possible to set a newly allocated value + // since the field is unexported. + // + // See https://golang.org/issue/21357 + if !subv.CanSet() { + d.saveError(fmt.Errorf("json: cannot set embedded pointer to unexported struct: %v", subv.Type().Elem())) + // Invalidate subv to ensure d.value(subv) skips over + // the JSON value without assigning it to subv. + subv = reflect.Value{} + destring = false + break + } + subv.Set(reflect.New(subv.Type().Elem())) + } + subv = subv.Elem() + } + if i < len(f.index)-1 { + d.errorContext.FieldStack = append( + d.errorContext.FieldStack, + subv.Type().Field(ind).Name, + ) + } + subv = subv.Field(ind) + } + d.errorContext.Struct = t + d.errorContext.FieldStack = append(d.errorContext.FieldStack, f.name) + } else if d.disallowUnknownFields { + d.saveError(fmt.Errorf("json: unknown field %q", key)) + } + } + + // Read : before value. 
		if d.opcode == scanSkipSpace {
			d.scanWhile(scanSkipSpace)
		}
		if d.opcode != scanObjectKey {
			panic(phasePanicMsg)
		}
		d.scanWhile(scanSkipSpace)

		// Read the value. Fields tagged ",string" (destring) carry their
		// real value wrapped inside a JSON string, so unwrap it first.
		if destring {
			switch qv := d.valueQuoted().(type) {
			case nil:
				if err := d.literalStore(nullLiteral, subv, false); err != nil {
					return err
				}
			case string:
				if err := d.literalStore([]byte(qv), subv, true); err != nil {
					return err
				}
			default:
				d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal unquoted value into %v", subv.Type()))
			}
		} else {
			if err := d.value(subv); err != nil {
				return err
			}
		}

		// Write value back to map;
		// if using struct, subv points into struct already.
		if v.Kind() == reflect.Map {
			kt := t.Key()
			var kv reflect.Value
			if reflect.PointerTo(kt).Implements(textUnmarshalerType) {
				// item still holds the raw quoted key bytes read above; let
				// the key type's UnmarshalText decode them.
				kv = reflect.New(kt)
				if err := d.literalStore(item, kv, true); err != nil {
					return err
				}
				kv = kv.Elem()
			} else {
				switch kt.Kind() {
				case reflect.String:
					kv = reflect.New(kt).Elem()
					kv.SetString(string(key))
				case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
					s := string(key)
					n, err := strconv.ParseInt(s, 10, 64)
					// SHIM(reflect): reflect.Type.OverflowInt(int64) bool
					okt := shims.OverflowableType{Type: kt}
					if err != nil || okt.OverflowInt(n) {
						// start+1 skips the key's opening quote so the offset
						// points at the digits themselves.
						d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: kt, Offset: int64(start + 1)})
						break
					}
					kv = reflect.New(kt).Elem()
					kv.SetInt(n)
				case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
					s := string(key)
					n, err := strconv.ParseUint(s, 10, 64)
					// SHIM(reflect): reflect.Type.OverflowUint(uint64) bool
					okt := shims.OverflowableType{Type: kt}
					if err != nil || okt.OverflowUint(n) {
						d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: kt, Offset: int64(start + 1)})
						break
					}
					kv = reflect.New(kt).Elem()
					kv.SetUint(n)
				default:
					panic("json: Unexpected key type") // should never occur
				}
			}
			if kv.IsValid() {
				v.SetMapIndex(kv, subv)
			}
		}

		// Next token must be , or }.
		if d.opcode == scanSkipSpace {
			d.scanWhile(scanSkipSpace)
		}
		if d.errorContext != nil {
			// Reset errorContext to its original state.
			// Keep the same underlying array for FieldStack, to reuse the
			// space and avoid unnecessary allocs.
			d.errorContext.FieldStack = d.errorContext.FieldStack[:len(origErrorContext.FieldStack)]
			d.errorContext.Struct = origErrorContext.Struct
		}
		if d.opcode == scanEndObject {
			break
		}
		if d.opcode != scanObjectValue {
			panic(phasePanicMsg)
		}
	}
	return nil
}

// convertNumber converts the number literal s to a float64 or a Number
// depending on the setting of d.useNumber.
func (d *decodeState) convertNumber(s string) (any, error) {
	if d.useNumber {
		return Number(s), nil
	}
	f, err := strconv.ParseFloat(s, 64)
	if err != nil {
		// SHIM(reflect): reflect.TypeFor[T]() reflect.Type
		return nil, &UnmarshalTypeError{Value: "number " + s, Type: shims.TypeFor[float64](), Offset: int64(d.off)}
	}
	return f, nil
}

// SHIM(reflect): TypeFor[T]() reflect.Type
// numberType lets literalStore recognize targets whose declared type is
// Number and validate the literal before storing it.
var numberType = shims.TypeFor[Number]()

// literalStore decodes a literal stored in item into v.
//
// fromQuoted indicates whether this literal came from unwrapping a
// string from the ",string" struct tag option. this is used only to
// produce more helpful error messages.
func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool) error {
	// Check for unmarshaler.
	if len(item) == 0 {
		// Empty string given.
		// An empty literal can only arise from a ",string" unwrap, so
		// report it in terms of the struct tag.
		d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
		return nil
	}
	isNull := item[0] == 'n' // null
	u, ut, pv := indirect(v, isNull)
	if u != nil {
		return u.UnmarshalJSON(item)
	}
	if ut != nil {
		// encoding.TextUnmarshaler targets only accept JSON strings.
		if item[0] != '"' {
			if fromQuoted {
				d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
				return nil
			}
			val := "number"
			switch item[0] {
			case 'n':
				val = "null"
			case 't', 'f':
				val = "bool"
			}
			d.saveError(&UnmarshalTypeError{Value: val, Type: v.Type(), Offset: int64(d.readIndex())})
			return nil
		}
		s, ok := unquoteBytes(item)
		if !ok {
			if fromQuoted {
				return fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())
			}
			panic(phasePanicMsg)
		}
		return ut.UnmarshalText(s)
	}

	v = pv

	switch c := item[0]; c {
	case 'n': // null
		// The main parser checks that only true and false can reach here,
		// but if this was a quoted string input, it could be anything.
		if fromQuoted && string(item) != "null" {
			d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
			break
		}
		switch v.Kind() {
		case reflect.Interface, reflect.Pointer, reflect.Map, reflect.Slice:
			v.SetZero()
			// otherwise, ignore null for primitives/string
		}
	case 't', 'f': // true, false
		value := item[0] == 't'
		// The main parser checks that only true and false can reach here,
		// but if this was a quoted string input, it could be anything.
		if fromQuoted && string(item) != "true" && string(item) != "false" {
			d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
			break
		}
		switch v.Kind() {
		default:
			if fromQuoted {
				d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
			} else {
				d.saveError(&UnmarshalTypeError{Value: "bool", Type: v.Type(), Offset: int64(d.readIndex())})
			}
		case reflect.Bool:
			v.SetBool(value)
		case reflect.Interface:
			if v.NumMethod() == 0 {
				v.Set(reflect.ValueOf(value))
			} else {
				d.saveError(&UnmarshalTypeError{Value: "bool", Type: v.Type(), Offset: int64(d.readIndex())})
			}
		}

	case '"': // string
		s, ok := unquoteBytes(item)
		if !ok {
			if fromQuoted {
				return fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())
			}
			panic(phasePanicMsg)
		}
		switch v.Kind() {
		default:
			d.saveError(&UnmarshalTypeError{Value: "string", Type: v.Type(), Offset: int64(d.readIndex())})
		case reflect.Slice:
			// A JSON string stored into a byte slice is base64 data.
			if v.Type().Elem().Kind() != reflect.Uint8 {
				d.saveError(&UnmarshalTypeError{Value: "string", Type: v.Type(), Offset: int64(d.readIndex())})
				break
			}
			b := make([]byte, base64.StdEncoding.DecodedLen(len(s)))
			n, err := base64.StdEncoding.Decode(b, s)
			if err != nil {
				d.saveError(err)
				break
			}
			v.SetBytes(b[:n])
		case reflect.String:
			t := string(s)
			if v.Type() == numberType && !isValidNumber(t) {
				return fmt.Errorf("json: invalid number literal, trying to unmarshal %q into Number", item)
			}
			v.SetString(t)
		case reflect.Interface:
			if v.NumMethod() == 0 {
				v.Set(reflect.ValueOf(string(s)))
			} else {
				d.saveError(&UnmarshalTypeError{Value: "string", Type: v.Type(), Offset: int64(d.readIndex())})
			}
		}

	default: // number
		if c != '-' && (c < '0' || c > '9') {
			if fromQuoted {
				return fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())
			}
			panic(phasePanicMsg)
		}
		switch v.Kind() {
		default:
			if v.Kind() == reflect.String && v.Type() == numberType {
				// s must be a valid number, because it's
				// already been tokenized.
				v.SetString(string(item))
				break
			}
			if fromQuoted {
				return fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())
			}
			d.saveError(&UnmarshalTypeError{Value: "number", Type: v.Type(), Offset: int64(d.readIndex())})
		case reflect.Interface:
			n, err := d.convertNumber(string(item))
			if err != nil {
				d.saveError(err)
				break
			}
			if v.NumMethod() != 0 {
				d.saveError(&UnmarshalTypeError{Value: "number", Type: v.Type(), Offset: int64(d.readIndex())})
				break
			}
			v.Set(reflect.ValueOf(n))

		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			n, err := strconv.ParseInt(string(item), 10, 64)
			if err != nil || v.OverflowInt(n) {
				d.saveError(&UnmarshalTypeError{Value: "number " + string(item), Type: v.Type(), Offset: int64(d.readIndex())})
				break
			}
			v.SetInt(n)

		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
			n, err := strconv.ParseUint(string(item), 10, 64)
			if err != nil || v.OverflowUint(n) {
				d.saveError(&UnmarshalTypeError{Value: "number " + string(item), Type: v.Type(), Offset: int64(d.readIndex())})
				break
			}
			v.SetUint(n)

		case reflect.Float32, reflect.Float64:
			n, err := strconv.ParseFloat(string(item), v.Type().Bits())
			if err != nil || v.OverflowFloat(n) {
				d.saveError(&UnmarshalTypeError{Value: "number " + string(item), Type: v.Type(), Offset: int64(d.readIndex())})
				break
			}
			v.SetFloat(n)
		}
	}
	return nil
}

// The xxxInterface routines build up a value to be stored
// in an empty interface. They are not strictly necessary,
// but they avoid the weight of reflection in this common case.

// valueInterface is like value but returns any.
func (d *decodeState) valueInterface() (val any) {
	switch d.opcode {
	default:
		panic(phasePanicMsg)
	case scanBeginArray:
		val = d.arrayInterface()
		d.scanNext()
	case scanBeginObject:
		val = d.objectInterface()
		d.scanNext()
	case scanBeginLiteral:
		val = d.literalInterface()
	}
	return
}

// arrayInterface is like array but returns []any.
func (d *decodeState) arrayInterface() []any {
	// Deliberately non-nil so an empty JSON array decodes to [] rather
	// than a nil slice.
	var v = make([]any, 0)
	for {
		// Look ahead for ] - can only happen on first iteration.
		d.scanWhile(scanSkipSpace)
		if d.opcode == scanEndArray {
			break
		}

		v = append(v, d.valueInterface())

		// Next token must be , or ].
		if d.opcode == scanSkipSpace {
			d.scanWhile(scanSkipSpace)
		}
		if d.opcode == scanEndArray {
			break
		}
		if d.opcode != scanArrayValue {
			panic(phasePanicMsg)
		}
	}
	return v
}

// objectInterface is like object but returns map[string]any.
func (d *decodeState) objectInterface() map[string]any {
	m := make(map[string]any)
	for {
		// Read opening " of string key or closing }.
		d.scanWhile(scanSkipSpace)
		if d.opcode == scanEndObject {
			// closing } - can only happen on first iteration.
			break
		}
		if d.opcode != scanBeginLiteral {
			panic(phasePanicMsg)
		}

		// Read string key.
		start := d.readIndex()
		d.rescanLiteral()
		item := d.data[start:d.readIndex()]
		key, ok := unquote(item)
		if !ok {
			panic(phasePanicMsg)
		}

		// Read : before value.
		if d.opcode == scanSkipSpace {
			d.scanWhile(scanSkipSpace)
		}
		if d.opcode != scanObjectKey {
			panic(phasePanicMsg)
		}
		d.scanWhile(scanSkipSpace)

		// Read value.
		m[key] = d.valueInterface()

		// Next token must be , or }.
		if d.opcode == scanSkipSpace {
			d.scanWhile(scanSkipSpace)
		}
		if d.opcode == scanEndObject {
			break
		}
		if d.opcode != scanObjectValue {
			panic(phasePanicMsg)
		}
	}
	return m
}

// literalInterface consumes and returns a literal from d.data[d.off-1:] and
// it reads the following byte ahead. The first byte of the literal has been
// read already (that's how the caller knows it's a literal).
func (d *decodeState) literalInterface() any {
	// All bytes inside literal return scanContinue op code.
	start := d.readIndex()
	d.rescanLiteral()

	item := d.data[start:d.readIndex()]

	switch c := item[0]; c {
	case 'n': // null
		return nil

	case 't', 'f': // true, false
		return c == 't'

	case '"': // string
		s, ok := unquote(item)
		if !ok {
			panic(phasePanicMsg)
		}
		return s

	default: // number
		if c != '-' && (c < '0' || c > '9') {
			panic(phasePanicMsg)
		}
		n, err := d.convertNumber(string(item))
		if err != nil {
			d.saveError(err)
		}
		return n
	}
}

// getu4 decodes \uXXXX from the beginning of s, returning the hex value,
// or it returns -1.
func getu4(s []byte) rune {
	if len(s) < 6 || s[0] != '\\' || s[1] != 'u' {
		return -1
	}
	var r rune
	// Accumulate the four hex digits; any non-hex byte invalidates the escape.
	for _, c := range s[2:6] {
		switch {
		case '0' <= c && c <= '9':
			c = c - '0'
		case 'a' <= c && c <= 'f':
			c = c - 'a' + 10
		case 'A' <= c && c <= 'F':
			c = c - 'A' + 10
		default:
			return -1
		}
		r = r*16 + rune(c)
	}
	return r
}

// unquote converts a quoted JSON string literal s into an actual string t.
// The rules are different than for Go, so cannot use strconv.Unquote.
func unquote(s []byte) (t string, ok bool) {
	s, ok = unquoteBytes(s)
	t = string(s)
	return
}

// unquoteBytes should be an internal detail,
// but widely used packages access it using linkname.
// Notable members of the hall of shame include:
//   - github.com/bytedance/sonic
//
// Do not remove or change the type signature.
// See go.dev/issue/67401.
//
//go:linkname unquoteBytes
func unquoteBytes(s []byte) (t []byte, ok bool) {
	if len(s) < 2 || s[0] != '"' || s[len(s)-1] != '"' {
		return
	}
	s = s[1 : len(s)-1]

	// Check for unusual characters. If there are none,
	// then no unquoting is needed, so return a slice of the
	// original bytes.
	// Fast path: scan for the first byte that needs work (backslash,
	// quote, control byte, or invalid UTF-8).
	r := 0
	for r < len(s) {
		c := s[r]
		if c == '\\' || c == '"' || c < ' ' {
			break
		}
		if c < utf8.RuneSelf {
			r++
			continue
		}
		rr, size := utf8.DecodeRune(s[r:])
		if rr == utf8.RuneError && size == 1 {
			break
		}
		r += size
	}
	if r == len(s) {
		return s, true
	}

	// Slow path: copy the clean prefix, then decode the remainder into a
	// fresh buffer with 2*UTFMax headroom per iteration.
	b := make([]byte, len(s)+2*utf8.UTFMax)
	w := copy(b, s[0:r])
	for r < len(s) {
		// Out of room? Can only happen if s is full of
		// malformed UTF-8 and we're replacing each
		// byte with RuneError.
		if w >= len(b)-2*utf8.UTFMax {
			nb := make([]byte, (len(b)+utf8.UTFMax)*2)
			copy(nb, b[0:w])
			b = nb
		}
		switch c := s[r]; {
		case c == '\\':
			r++
			if r >= len(s) {
				return
			}
			switch s[r] {
			default:
				return
			case '"', '\\', '/', '\'':
				b[w] = s[r]
				r++
				w++
			case 'b':
				b[w] = '\b'
				r++
				w++
			case 'f':
				b[w] = '\f'
				r++
				w++
			case 'n':
				b[w] = '\n'
				r++
				w++
			case 'r':
				b[w] = '\r'
				r++
				w++
			case 't':
				b[w] = '\t'
				r++
				w++
			case 'u':
				r--
				rr := getu4(s[r:])
				if rr < 0 {
					return
				}
				r += 6
				if utf16.IsSurrogate(rr) {
					rr1 := getu4(s[r:])
					if dec := utf16.DecodeRune(rr, rr1); dec != unicode.ReplacementChar {
						// A valid pair; consume.
						r += 6
						w += utf8.EncodeRune(b[w:], dec)
						break
					}
					// Invalid surrogate; fall back to replacement rune.
					rr = unicode.ReplacementChar
				}
				w += utf8.EncodeRune(b[w:], rr)
			}

		// Quote, control characters are invalid.
		case c == '"', c < ' ':
			return

		// ASCII
		case c < utf8.RuneSelf:
			b[w] = c
			r++
			w++

		// Coerce to well-formed UTF-8.
		default:
			rr, size := utf8.DecodeRune(s[r:])
			r += size
			w += utf8.EncodeRune(b[w:], rr)
		}
	}
	return b[0:w], true
}
diff --git a/internal/encoding/json/encode.go b/internal/encoding/json/encode.go
new file mode 100644
index 0000000..d51c3e9
--- /dev/null
+++ b/internal/encoding/json/encode.go
@@ -0,0 +1,1391 @@
+// Vendored from Go 1.24.0-pre-release
+// To find alterations, check package shims, and comments beginning in SHIM().
+// +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package json implements encoding and decoding of JSON as defined in +// RFC 7159. The mapping between JSON and Go values is described +// in the documentation for the Marshal and Unmarshal functions. +// +// See "JSON and Go" for an introduction to this package: +// https://golang.org/doc/articles/json_and_go.html +package json + +import ( + "bytes" + "cmp" + "encoding" + "encoding/base64" + "fmt" + "github.com/ScrapeGraphAI/scrapegraph-sdk/internal/encoding/json/sentinel" + "github.com/ScrapeGraphAI/scrapegraph-sdk/internal/encoding/json/shims" + "math" + "reflect" + "slices" + "strconv" + "strings" + "sync" + "unicode" + "unicode/utf8" + _ "unsafe" // for linkname +) + +// Marshal returns the JSON encoding of v. +// +// Marshal traverses the value v recursively. +// If an encountered value implements [Marshaler] +// and is not a nil pointer, Marshal calls [Marshaler.MarshalJSON] +// to produce JSON. If no [Marshaler.MarshalJSON] method is present but the +// value implements [encoding.TextMarshaler] instead, Marshal calls +// [encoding.TextMarshaler.MarshalText] and encodes the result as a JSON string. +// The nil pointer exception is not strictly necessary +// but mimics a similar, necessary exception in the behavior of +// [Unmarshaler.UnmarshalJSON]. +// +// Otherwise, Marshal uses the following type-dependent default encodings: +// +// Boolean values encode as JSON booleans. +// +// Floating point, integer, and [Number] values encode as JSON numbers. +// NaN and +/-Inf values will return an [UnsupportedValueError]. +// +// String values encode as JSON strings coerced to valid UTF-8, +// replacing invalid bytes with the Unicode replacement rune. +// So that the JSON will be safe to embed inside HTML