18 changes: 18 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,23 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Added

- `stream_io.open_stream()` now respects Boto3's configuration files
and environment variables when searching for object storage credentials to use

### Fixed

- `stream_io.open_stream()` now uses virtual-hosted-style
bucket addressing for the `cwobject.com` and `cwlota.com` endpoints
- `stream_io.open_stream()` now allows the `use_https` entry of `.s3cfg`
configuration files to fill in its `force_http` parameter if `force_http` is
not explicitly specified as `True` or `False`
- `TensorSerializer` no longer throws an error when attempting to serialize
very large tensors on some non-Linux platforms

## [2.9.3] - 2025-05-09

### Changed
@@ -424,6 +441,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
- `get_gpu_name`
- `no_init_or_tensor`

[Unreleased]: https://github.com/coreweave/tensorizer/compare/v2.9.3...HEAD
[2.9.3]: https://github.com/coreweave/tensorizer/compare/v2.9.2...v2.9.3
[2.9.2]: https://github.com/coreweave/tensorizer/compare/v2.9.1...v2.9.2
[2.9.1]: https://github.com/coreweave/tensorizer/compare/v2.9.0...v2.9.1
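A minimal sketch of how the changelog entries above surface to callers. The bucket URI and the environment variables here are illustrative assumptions, not part of this diff:

    # Hypothetical usage: with no credentials passed explicitly,
    # stream_io.open_stream() can now fall back to Boto3's config
    # files and environment variables to find them.
    import os
    from tensorizer import stream_io

    os.environ.setdefault("AWS_ACCESS_KEY_ID", "EXAMPLEKEY")         # illustrative
    os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "EXAMPLESECRET")  # illustrative

    with stream_io.open_stream("s3://tensorized/model.tensors", "rb") as f:
        magic = f.read(8)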
30 changes: 24 additions & 6 deletions tensorizer/serialization.py
@@ -21,6 +21,7 @@
import queue
import stat
import struct
import sys
import threading
import time
import types
@@ -3506,16 +3507,33 @@ def _pwrite(
raise RuntimeError("pwrite was called before being initialized")

@staticmethod
def _mv_suffix(data: "collections.abc.Buffer", start: int):
def _mv_slice(data: "collections.abc.Buffer", s: slice):
if not isinstance(data, memoryview):
data = memoryview(data)
try:
if data.ndim != 1 or data.format != "B":
data = data.cast("B")
return data[start:]
return data[s]
finally:
del data

if sys.platform == "linux":
_pwrite_compat = staticmethod(os.pwrite)
else:

@staticmethod
def _pwrite_compat(_fd, _str, _offset, /):
# Some systems error on single I/O calls larger than the maximum
# value of a signed 32-bit integer, so limit os.pwrite calls
# to a maximum size of about one memory page less than that
MAX_LEN: typing.Final[int] = 2147479552

if len(_str) > MAX_LEN:
with TensorSerializer._mv_slice(_str, slice(MAX_LEN)) as mv:
return os.pwrite(_fd, mv, _offset)

return os.pwrite(_fd, _str, _offset)

def _pwrite_syscall(
self, data, offset: int, verify: Union[bool, int] = True
) -> int:
@@ -3525,14 +3543,14 @@ def _pwrite_syscall(
expected_bytes_written: int = (
verify if isinstance(verify, int) else self._buffer_size(data)
)
bytes_just_written: int = os.pwrite(self._fd, data, offset)
bytes_just_written: int = self._pwrite_compat(self._fd, data, offset)
if bytes_just_written > 0:
bytes_written += bytes_just_written
while bytes_written < expected_bytes_written and bytes_just_written > 0:
# Writes larger than ~2 GiB may not complete in a single pwrite call
offset += bytes_just_written
with self._mv_suffix(data, bytes_written) as mv:
bytes_just_written = os.pwrite(self._fd, mv, offset)
with self._mv_slice(data, slice(bytes_written, None)) as mv:
bytes_just_written = self._pwrite_compat(self._fd, mv, offset)
if bytes_just_written > 0:
bytes_written += bytes_just_written
if isinstance(verify, int) or verify:
@@ -3553,7 +3571,7 @@ def _write(self, data, expected_bytes_written: Optional[int] = None) -> int:
if bytes_just_written > expected_bytes_written:
raise ValueError("Wrote more data than expected")
while bytes_written < expected_bytes_written and bytes_just_written > 0:
with self._mv_suffix(data, bytes_written) as mv:
with self._mv_slice(data, slice(bytes_written, None)) as mv:
bytes_just_written = self._file.write(mv)
bytes_written += bytes_just_written
return bytes_written
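The 2,147,479,552-byte cap in `_pwrite_compat` above is 0x7FFFF000, one 4 KiB page below 2**31; with 4 KiB pages this is the same per-call limit (MAX_RW_COUNT) Linux applies to read/write. As a standalone illustration of the same chunking idea, assuming a plain bytes buffer and a hypothetical helper name:

    import os

    # 2**31 - 2**12: some platforms reject single I/O calls at or above
    # the signed 32-bit boundary, so stay one memory page below it.
    _MAX_IO = 2147479552

    def pwrite_all(fd: int, data: bytes, offset: int) -> int:
        """Hypothetical helper: keep calling os.pwrite until done."""
        mv = memoryview(data)
        written = 0
        while written < len(mv):
            # Each call writes at most _MAX_IO bytes; short writes advance
            # the offset and retry, like _pwrite_syscall above.
            n = os.pwrite(fd, mv[written : written + _MAX_IO], offset + written)
            if n <= 0:
                raise OSError("pwrite made no progress")
            written += n
        return written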
147 changes: 100 additions & 47 deletions tensorizer/stream_io.py
@@ -19,6 +19,8 @@

import boto3
import botocore
import botocore.exceptions
import botocore.session
import redis

import tensorizer._version as _version
@@ -57,6 +59,7 @@ class _ParsedCredentials(typing.NamedTuple):
s3_endpoint: Optional[str]
s3_access_key: Optional[str]
s3_secret_key: Optional[str]
use_https: Optional[bool]


@functools.lru_cache(maxsize=None)
@@ -105,16 +108,24 @@ def _get_s3cfg_values(
if config.read((config_path,)):
break
else:
return _ParsedCredentials(None, None, None, None)
return _ParsedCredentials(None, None, None, None, None)

if "default" not in config:
raise ValueError(f"No default section in {config_path}")

use_https = config["default"].get("use_https")
if use_https == "True":
use_https = True
elif use_https == "False":
use_https = False
else:
use_https = None
return _ParsedCredentials(
config_file=os.fsdecode(config_path),
s3_endpoint=config["default"].get("host_base"),
s3_access_key=config["default"].get("access_key"),
s3_secret_key=config["default"].get("secret_key"),
use_https=use_https,
)
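For context on the hunk above, a self-contained sketch of how these `.s3cfg` keys parse; the file contents are illustrative, and the real code reads a file path rather than a string:

    import configparser

    # Illustrative .s3cfg contents; key names match what
    # _get_s3cfg_values() reads from the [default] section.
    sample = """
    [default]
    host_base = object.example.com
    access_key = EXAMPLEKEY
    secret_key = EXAMPLESECRET
    use_https = True
    """

    config = configparser.ConfigParser()
    config.read_string(sample)
    # Values arrive as strings, hence the explicit "True"/"False"
    # comparison in the diff above.
    assert config["default"].get("use_https") == "True"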


@@ -897,23 +908,24 @@ def _ensure_https_endpoint(endpoint: str):
raise ValueError("Non-HTTPS endpoint URLs are not allowed.")


def _is_caios(endpoint: str) -> bool:
host = urlparse(endpoint if "//" in endpoint else "//" + endpoint).netloc
return host.lower() in {"cwobject.com", "cwlota.com"}
Contributor:

Really not a fan of us hardcoding these like this (or the ord1 defaults we have had elsewhere for some time), but I can be convinced this is OK.

For reference, these are the other hardcoded values, which should either A) be dropped, or B) be changed to our current object storage offering:

    default_s3_read_endpoint = "accel-object.ord1.coreweave.com"
    default_s3_write_endpoint = "object.ord1.coreweave.com"

Collaborator (Author):

I'm pretty sure boto3's source tree contains a large data segment with these sorts of hardcoded compatibility checks for a huge set of different AWS endpoints; that's how Boto3's "magically pick the right one" feature works. Unless/until we upstream a change to support our object storage endpoints correctly, which I'm not sure they'd accept, this seems to me like a fair place to put a workaround.

The other hardcoded values are necessary for the default s3://tensorized bucket to "just work", since that's where it's located. I'd like to change it to the newer endpoint in release 3.0.

Contributor:

If this is a pattern already used in boto3, then it's probably fine, although we are not automatically changing force_http based on whether we are using cwlota.com, which would fall under the same pattern as what we are doing here. The same points apply on both sides: if a user configures it correctly it will work, but with only minimal configuration it will not work by default.

Collaborator (Author):

I am nervous about defaulting force_http=True for the cwlota.com endpoint, on the basis that it could transmit a presigned URL in cleartext over an unknown network if it is used somewhere where cwlota.com doesn't resolve to a local network (so not a CW datacenter; maybe someone's own machine where they're trying to test something). I think the presigned URL would only be valid to someone who can already access cwlota.com endpoints, but I'd rather leave that one as something that must be declared explicitly in the config.

Contributor:

As long as it works correctly when http://cwlota.com is in the config or environment variables as the endpoint, then I don't feel strongly that it needs something here explicitly forcing it.
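A standalone sketch of the addressing-style workaround under discussion, using boto3's public Config API; the endpoint URL is one of the two hardcoded in `_is_caios`, and the client construction here is illustrative rather than the library's exact call:

    import boto3
    from botocore.config import Config

    # cwobject.com and cwlota.com do not support path-style addressing,
    # so request virtual-hosted-style bucket addressing explicitly.
    client = boto3.client(
        "s3",
        endpoint_url="https://cwobject.com",
        config=Config(s3={"addressing_style": "virtual"}),
    )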



def _new_s3_client(
s3_access_key_id: str,
s3_secret_access_key: str,
s3_endpoint: str,
s3_access_key_id: Optional[str],
s3_secret_access_key: Optional[str],
s3_endpoint: Optional[str],
s3_region_name: Optional[str] = None,
s3_signature_version: Optional[str] = None,
):
if s3_secret_access_key is None:
raise TypeError("No secret key provided")
if s3_access_key_id is None:
raise TypeError("No access key provided")
if s3_endpoint is None:
raise TypeError("No S3 endpoint provided")

config_args = dict(user_agent=_BOTO_USER_AGENT)
auth_args = {}

if _is_caios(s3_endpoint):
# These endpoints don't support path-style addressing
config_args["s3"] = {"addressing_style": "virtual"}
if s3_region_name is not None:
config_args["region_name"] = s3_region_name

@@ -952,9 +964,9 @@ def _parse_s3_uri(uri: str) -> Tuple[str, str]:
def s3_upload(
path: str,
target_uri: str,
s3_access_key_id: str,
s3_secret_access_key: str,
s3_endpoint: str = default_s3_write_endpoint,
s3_access_key_id: Optional[str],
s3_secret_access_key: Optional[str],
s3_endpoint: Optional[str] = default_s3_write_endpoint,
s3_region_name: Optional[str] = None,
s3_signature_version: Optional[str] = None,
):
@@ -971,9 +983,9 @@

def _s3_download_url(
path_uri: str,
s3_access_key_id: str,
s3_secret_access_key: str,
s3_endpoint: str = default_s3_read_endpoint,
s3_access_key_id: Optional[str],
s3_secret_access_key: Optional[str],
s3_endpoint: Optional[str] = default_s3_read_endpoint,
s3_region_name: Optional[str] = None,
s3_signature_version: Optional[str] = None,
) -> str:
@@ -1021,19 +1033,40 @@ def _s3_download_url(
expiry = t - (t % SIG_GRANULARITY) + (SIG_GRANULARITY * 2)
seconds_to_expiry = expiry - t

url = client.generate_presigned_url(
ClientMethod="get_object",
Params={"Bucket": bucket, "Key": key},
ExpiresIn=seconds_to_expiry,
)
try:
# This is the first point at which an error may be raised by boto3
# for missing credentials
url = client.generate_presigned_url(
ClientMethod="get_object",
Params={"Bucket": bucket, "Key": key},
ExpiresIn=seconds_to_expiry,
)
except botocore.exceptions.NoCredentialsError:
if s3_access_key_id is None and s3_secret_access_key is None:
# Credentials may be absent because a public read
# bucket is being used, so try blank credentials
try:
return _s3_download_url(
path_uri,
"",
"",
s3_endpoint,
s3_region_name,
s3_signature_version,
)
except botocore.exceptions.NoCredentialsError:
# If this has the same error for some reason,
# just ignore it, and raise the original error
pass
raise
return url
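The blank-credentials retry above targets public-read buckets. A related, well-known boto3 pattern is an unsigned client, shown here as an alternative sketch rather than what the diff itself does:

    import boto3
    from botocore import UNSIGNED
    from botocore.config import Config

    # An anonymous client for public buckets: requests are sent unsigned,
    # so no credential lookup happens at all.
    client = boto3.client("s3", config=Config(signature_version=UNSIGNED))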


def s3_download(
path_uri: str,
s3_access_key_id: str,
s3_secret_access_key: str,
s3_endpoint: str = default_s3_read_endpoint,
s3_access_key_id: Optional[str],
s3_secret_access_key: Optional[str],
s3_endpoint: Optional[str] = default_s3_read_endpoint,
s3_region_name: Optional[str] = None,
s3_signature_version: Optional[str] = None,
buffer_size: Optional[int] = None,
@@ -1101,6 +1134,7 @@ def _infer_credentials(
s3_endpoint=None,
s3_access_key=s3_access_key_id,
s3_secret_key=s3_secret_access_key,
use_https=None,
)

# Try to find default credentials if at least one is not specified
@@ -1149,6 +1183,7 @@ def _infer_credentials(
s3_endpoint=parsed.s3_endpoint,
s3_access_key=s3_access_key_id,
s3_secret_key=s3_secret_access_key,
use_https=parsed.use_https,
)


@@ -1197,7 +1232,7 @@ def open_stream(
s3_endpoint: Optional[str] = None,
s3_config_path: Optional[Union[str, bytes, os.PathLike]] = None,
buffer_size: Optional[int] = None,
force_http: bool = False,
force_http: Optional[bool] = None,
*,
begin: Optional[int] = None,
end: Optional[int] = None,
@@ -1364,30 +1399,48 @@
# Not required to have been found,
# and doesn't overwrite an explicitly specified endpoint.
s3_endpoint = s3_endpoint or s3.s3_endpoint
except (ValueError, FileNotFoundError) as e:
# Uploads always require credentials here, but downloads may not
if is_s3_upload:
raise
else:
# Credentials may be absent because a public read
# bucket is being used, so try blank credentials,
# but provide a descriptive warning for future errors
# that may occur due to this exception being suppressed.
# Don't save the whole exception object since it holds
# a stack trace, which can interfere with garbage collection.
error_context = (
"Warning: empty credentials were used for S3."
f"\nReason: {e}"
"\nIf the connection failed due to missing permissions"
" (e.g. HTTP error 403), try providing credentials"
" directly with the tensorizer.stream_io.open_stream()"
" function."
)
s3_access_key_id = s3_access_key_id or ""
s3_secret_access_key = s3_access_key_id or ""

if force_http is None and s3.use_https is not None:
force_http = not s3.use_https
except (ValueError, FileNotFoundError):
# TODO: Reimplement this logic somewhere in s3_download
#
# Credentials may be absent because a public read
# bucket is being used, so try blank credentials,
# but provide a descriptive warning for future errors
# that may occur due to this exception being suppressed.
# Don't save the whole exception object since it holds
# a stack trace, which can interfere with garbage collection.
#
# error_context = (
# "Warning: empty credentials were used for S3."
# f"\nReason: {e}"
# "\nIf the connection failed due to missing permissions"
# " (e.g. HTTP error 403), try providing credentials"
# " directly with the tensorizer.stream_io.open_stream()"
# " function."
# )
pass

if force_http is None:
force_http = False

# Regardless of whether the config needed to be parsed,
# the endpoint gets a default value based on the operation.
# First, check for the AWS_ENDPOINT_URL environment variables,
# otherwise, check whether botocore's configuration resolves an endpoint,
# otherwise, use default_s3_write_endpoint and default_s3_read_endpoint.
if not s3_endpoint:
s3_endpoint = os.environ.get("AWS_ENDPOINT_URL_S3") or None
if not s3_endpoint:
s3_endpoint = os.environ.get("AWS_ENDPOINT_URL") or None
if not s3_endpoint:
scoped_config = botocore.session.Session().get_scoped_config()
s3_config = scoped_config.get("s3")
if s3_config:
s3_endpoint = s3_config.get("endpoint_url") or None
if not s3_endpoint:
s3_endpoint = scoped_config.get("endpoint_url") or None

if is_s3_upload:
s3_endpoint = s3_endpoint or default_s3_write_endpoint
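Finally, the endpoint resolution order implemented above, restated as a self-contained sketch with a hypothetical helper name (the real logic lives inline in `open_stream`):

    import os
    from typing import Optional

    import botocore.session

    def resolve_endpoint(explicit: Optional[str]) -> Optional[str]:
        """Hypothetical restatement: explicit argument, then environment
        variables, then botocore's scoped config; callers fall back to the
        read/write defaults if this returns None."""
        if explicit:
            return explicit
        for var in ("AWS_ENDPOINT_URL_S3", "AWS_ENDPOINT_URL"):
            if os.environ.get(var):
                return os.environ[var]
        scoped = botocore.session.Session().get_scoped_config()
        s3_config = scoped.get("s3") or {}
        return s3_config.get("endpoint_url") or scoped.get("endpoint_url") or None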