diff --git a/pyproject.toml b/pyproject.toml
index 8de8ce2a..a9d283f5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,9 +14,9 @@ dynamic = ["version"]
 readme = "README.md"
 dependencies = [
     "torch>=2.7",
-    "opentelemetry-exporter-otlp-proto-http>=1.37.0",
-    "opentelemetry-sdk>=1.37.0",
-    "opentelemetry-api>=1.37.0",
+    "opentelemetry-exporter-otlp-proto-http>=1.39.0",
+    "opentelemetry-sdk>=1.39.0",
+    "opentelemetry-api>=1.39.0",
 ]
 
 [project.urls]
diff --git a/torchft/optim.py b/torchft/optim.py
index a2884392..b05d5deb 100644
--- a/torchft/optim.py
+++ b/torchft/optim.py
@@ -59,5 +59,5 @@ def param_groups(self) -> List[Dict[str, Any]]:
         return self.optim.param_groups
 
     @property
-    def state(self) -> Mapping[torch.Tensor, Any]:  # pyre-fixme[3]
+    def state(self) -> Mapping[torch.Tensor, object]:
         return self.optim.state
diff --git a/torchft/otel.py b/torchft/otel.py
index 9f927a5b..efe0a4be 100644
--- a/torchft/otel.py
+++ b/torchft/otel.py
@@ -13,12 +13,12 @@
 
 from opentelemetry._logs import set_logger_provider
 from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter
-from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler
-from opentelemetry.sdk._logs._internal import LogData
+from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler, ReadableLogRecord
 from opentelemetry.sdk._logs.export import (
     BatchLogRecordProcessor,
-    ConsoleLogExporter,
-    LogExporter,
+    ConsoleLogRecordExporter,
+    LogRecordExporter,
+    LogRecordExportResult,
 )
 from opentelemetry.sdk.resources import Resource
 
@@ -27,18 +27,19 @@
 TORCHFT_OTEL_RESOURCE_ATTRIBUTES_JSON = "TORCHFT_OTEL_RESOURCE_ATTRIBUTES_JSON"
 
 
-class TeeLogExporter(LogExporter):
+class TeeLogExporter(LogRecordExporter):
     """Exporter that writes to multiple exporters."""
 
     def __init__(
         self,
-        exporters: List[LogExporter],
+        exporters: List[LogRecordExporter],
     ) -> None:
         self._exporters = exporters
 
-    def export(self, batch: Sequence[LogData]) -> None:
+    def export(self, batch: Sequence[ReadableLogRecord]) -> LogRecordExportResult:
         for e in self._exporters:
             e.export(batch)
+        return LogRecordExportResult.SUCCESS
 
     def shutdown(self) -> None:
         for e in self._exporters:
@@ -49,8 +50,6 @@ def setup_logger(name: str) -> None:
     if os.environ.get("TORCHFT_USE_OTEL", "false") == "false":
         return
 
-    global _LOGGER_PROVIDER
-
     if name in _LOGGER_PROVIDER:
         return
 
@@ -70,7 +69,7 @@
 
     exporter = TeeLogExporter(
         exporters=[
-            ConsoleLogExporter(),
+            ConsoleLogRecordExporter(),
             OTLPLogExporter(
                 timeout=5,
             ),
diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py
index 2f323ad5..625799c8 100644
--- a/torchft/process_group_test.py
+++ b/torchft/process_group_test.py
@@ -76,7 +76,7 @@ def _test_pg(
     ]
     tensor_list = [torch.empty_like(input_tensor)]
 
-    def check_tensors(arg: Any) -> None:  # pyre-ignore[2]
+    def check_tensors(arg: object) -> None:
         """Recursively check tensors for expected shape and dtype."""
         if isinstance(arg, torch.Tensor):
             assert arg.dtype == dtype, f"Output dtype mismatch: {arg.dtype} != {dtype}"
@@ -738,7 +738,10 @@ def test_functional_collectives(self) -> None:
 
         self.assertEqual(pg.group_name, str(dist.get_pg_count() - 1))
 
-        self.assertIs(_resolve_process_group(pg.group_name), pg)
+        self.assertIs(
+            _resolve_process_group(pg.group_name),  # pyre-ignore[6]: GroupName vs str
+            pg,
+        )
 
         try:
             t = torch.zeros(10)