|
7 | 7 | from typing import Any, Optional |
8 | 8 |
|
9 | 9 | if typing.TYPE_CHECKING: |
| 10 | + from cloud_pipelines.orchestration.storage_providers import ( |
| 11 | + interfaces as storage_provider_interfaces, |
| 12 | + ) |
10 | 13 | from .launchers import interfaces as launcher_interfaces |
11 | 14 |
|
12 | 15 |
|
@@ -756,8 +759,16 @@ def get_container_execution_log( |
756 | 759 | raise RuntimeError( |
757 | 760 | f"Container execution {container_execution.id=} does not have log_uri. Impossible." |
758 | 761 | ) |
| 762 | + # TODO: Make the ContainerLauncher._storage_provider part of the public interface or create a better solution for log retrieval |
| 763 | + # Try getting the configured storage provider from the launcher so that it has correct access credentials. |
| 764 | + storage_provider = ( |
| 765 | + getattr(container_launcher, "_storage_provider", None) |
| 766 | + if container_launcher |
| 767 | + else None |
| 768 | + ) |
759 | 769 | log_text = _read_container_execution_log_from_uri( |
760 | | - container_execution.log_uri |
| 770 | + log_uri=container_execution.log_uri, |
| 771 | + storage_provider=storage_provider, |
761 | 772 | ) |
762 | 773 | except: |
763 | 774 | # Do not raise exception if the execution is in SYSTEM_ERROR state |
@@ -820,18 +831,41 @@ def stream_container_execution_log( |
820 | 831 | raise RuntimeError( |
821 | 832 | f"Container execution {container_execution.id=} does not have log_uri. Impossible." |
822 | 833 | ) |
| 834 | + # TODO: Make the ContainerLauncher._storage_provider part of the public interface or create a better solution for log retrieval |
| 835 | + # Try getting the configured storage provider from the launcher so that it has correct access credentials. |
| 836 | + storage_provider = ( |
| 837 | + getattr(container_launcher, "_storage_provider", None) |
| 838 | + if container_launcher |
| 839 | + else None |
| 840 | + ) |
823 | 841 | log_text = _read_container_execution_log_from_uri( |
824 | | - container_execution.log_uri |
| 842 | + log_uri=container_execution.log_uri, |
| 843 | + storage_provider=storage_provider, |
825 | 844 | ) |
826 | 845 | return (line + "\n" for line in log_text.split("\n")) |
827 | 846 |
|
828 | 847 |
|
829 | | -def _read_container_execution_log_from_uri(log_uri: str): |
830 | | - if "://" not in log_uri and ".." not in log_uri: |
| 848 | +def _read_container_execution_log_from_uri( |
| 849 | + log_uri: str, |
| 850 | + storage_provider: "storage_provider_interfaces.StorageProvider | None" = None, |
| 851 | +) -> str: |
| 852 | + if ".." in log_uri: |
| 853 | + raise ValueError( |
| 854 | + f"_read_container_execution_log_from_uri: log_uri contains '..': {log_uri=}" |
| 855 | + ) |
| 856 | + |
| 857 | + if storage_provider: |
| 858 | + # TODO: Switch to storage_provider.parse_uri_get_accessor |
| 859 | + uri_accessor = storage_provider.make_uri(log_uri) |
| 860 | + log_text = uri_accessor.get_reader().download_as_text() |
| 861 | + return log_text |
| 862 | + |
| 863 | + if "://" not in log_uri: |
831 | 864 | # Consider the URL to be an absolute local path (`/path` or `C:\path` or `C:/path`) |
832 | 865 | with open(log_uri, "r") as reader: |
833 | 866 | return reader.read() |
834 | 867 | elif log_uri.startswith("gs://"): |
| 868 | + # TODO: Switch to using storage providers. |
835 | 869 | from google.cloud import storage |
836 | 870 |
|
837 | 871 | gcs_client = storage.Client() |
|
0 commit comments