diff --git a/.github/workflows/mlperf-inference-bert.yml b/.github/workflows/mlperf-inference-bert.yml
new file mode 100644
index 000000000..9f08a954d
--- /dev/null
+++ b/.github/workflows/mlperf-inference-bert.yml
@@ -0,0 +1,48 @@
+name: MLPerf inference bert (deepsparse, tf, onnxruntime, pytorch)
+
+on:
+  pull_request:
+    branches: [ "main", "dev" ]
+    paths:
+      - '.github/workflows/mlperf-inference-bert.yml'
+      - '**'
+      - '!**.md'
+
+jobs:
+  build:
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        # 3.12 didn't work as of 2024-03-05 - need to check
+        python-version: [ "3.11" ]
+        backend: [ "deepsparse", "tf", "onnxruntime", "pytorch" ]
+        precision: [ "int8", "fp32" ]
+        os: [ubuntu-latest, windows-latest, macos-latest]
+        exclude:
+          - backend: tf
+          - backend: pytorch
+          - backend: onnxruntime
+          - precision: fp32
+          - os: windows-latest
+
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v3
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install mlcflow
+        run: |
+          python -m pip install --upgrade pip
+          python -m pip install --ignore-installed --verbose pip setuptools
+          python -m pip install .
+          mlc pull repo mlcommons@mlperf-automations --branch=dev
+      - name: Test MLPerf Inference Bert ${{ matrix.backend }} on ${{ matrix.os }}
+        if: matrix.os == 'windows-latest'
+        run: |
+          mlcr --tags=run,mlperf,inference,generate-run-cmds,_submission,_short --submitter="MLCommons" --hw_name=gh_${{ matrix.os }} --model=bert-99 --backend=${{ matrix.backend }} --device=cpu --scenario=Offline --test_query_count=5 --adr.loadgen.tags=_from-pip --pip_loadgen=yes --precision=${{ matrix.precision }} --target_qps=1 -v --quiet
+      - name: Test MLPerf Inference Bert ${{ matrix.backend }} on ${{ matrix.os }}
+        if: matrix.os != 'windows-latest'
+        run: |
+          mlcr --tags=run,mlperf,inference,generate-run-cmds,_submission,_short --submitter="MLCommons" --hw_name=gh_${{ matrix.os }}_x86 --model=bert-99 --backend=${{ matrix.backend }} --device=cpu --scenario=Offline --test_query_count=5 --precision=${{ matrix.precision }} --target_qps=1 -v --quiet
diff --git a/.github/workflows/mlperf-inference-resnet50.yml b/.github/workflows/mlperf-inference-resnet50.yml
new file mode 100644
index 000000000..6c236b7f3
--- /dev/null
+++ b/.github/workflows/mlperf-inference-resnet50.yml
@@ -0,0 +1,51 @@
+name: 'MLPerf inference resnet50'
+
+on:
+  pull_request:
+    branches: [ "main", "dev" ]
+    paths:
+      - '.github/workflows/mlperf-inference-resnet50.yml'
+      - '**'
+      - '!**.md'
+
+jobs:
+  build:
+
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: ["3.12", "3.11", "3.8"]
+        on: [ubuntu-latest, macos-latest, windows-latest]
+        backend: [ "onnxruntime", "tf" ]
+        implementation: [ "python", "cpp" ]
+        exclude:
+          - backend: tf
+            implementation: cpp
+          - on: windows-latest
+            implementation: cpp
+    runs-on: "${{ matrix.on }}"
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          ref: ${{ github.event.pull_request.head.sha }}
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v3
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          python -m pip install --ignore-installed --verbose pip setuptools
+          python -m pip install .
+          mlc pull repo mlcommons@mlperf-automations --branch=dev
+
+      - name: Test MLPerf inference ResNet50 on Windows (prebuilt loadgen)
+        if: runner.os == 'Windows'
+        run: |
+          mlc run script --tags=run-mlperf,inference,_submission,_short --submitter="MLCommons" --hw_name=gh_action --model=resnet50 --implementation=${{ matrix.implementation }} --backend=${{ matrix.backend }} --device=cpu --scenario=Offline --test_query_count=100 --target_qps=1 -v --quiet --adr.loadgen.tags=_from-pip --pip_loadgen=yes
+
+      - name: Test MLPerf inference ResNet50 on Unix systems
+        if: runner.os != 'Windows'
+        run: |
+          mlc run script --tags=run-mlperf,inference,_submission,_short --submitter="MLCommons" --hw_name=gh_action --model=resnet50 --implementation=${{ matrix.implementation }} --backend=${{ matrix.backend }} --device=cpu --scenario=Offline --test_query_count=100 --target_qps=1 -v --quiet
diff --git a/config.yaml b/config.yaml
deleted file mode 100644
index db7a07d22..000000000
--- a/config.yaml
+++ /dev/null
@@ -1,4 +0,0 @@
-MLC_LOCAL_CACHE_FOLDER: local
-local_repo_meta:
-  alias: local
-  uid: 9a3280b14a4285c9
diff --git a/mlc/__init__.py b/mlc/__init__.py
new file mode 100644
index 000000000..943c20046
--- /dev/null
+++ b/mlc/__init__.py
@@ -0,0 +1,5 @@
+__version__ = "0.1.0"
+
+from .main import access
+
+__all__ = ['access']
diff --git a/mlc/main.py b/mlc/main.py
index e47dd5846..905866ca5 100644
--- a/mlc/main.py
+++ b/mlc/main.py
@@ -8,11 +8,67 @@
 import yaml
 import sys
 import logging
+from types import SimpleNamespace
 import mlc.utils as utils
+from pathlib import Path
+from colorama import Fore, Style, init
+import shutil
+
+# Initialize colorama for Windows support
+init(autoreset=True)
+
+class ColoredFormatter(logging.Formatter):
+    """Custom formatter class to add colors to log levels"""
+    COLORS = {
+        'INFO': Fore.GREEN,
+        'WARNING': Fore.YELLOW,
+        'ERROR': Fore.RED
+    }
+
+    def format(self, record):
+        # Add color to the levelname
+        if record.levelname in self.COLORS:
+            record.levelname = f"{self.COLORS[record.levelname]}{record.levelname}{Style.RESET_ALL}"
+        return super().format(record)
+
-# Set up logging configuration
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+
 logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+# Create console handler with the custom formatter
+console_handler = logging.StreamHandler()
+console_handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
+
+# Remove any existing handlers and add our custom handler
+if logger.hasHandlers():
+    logger.handlers.clear()
+
+logger.addHandler(console_handler)
+
+# # Set up logging configuration
+# logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+# logger = logging.getLogger(__name__)
+
+
+# Set up logging configuration
+def setup_logging(log_path='mlc', log_file='mlc-log.txt'):
+
+    logFormatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+
+    # File handler for logging to a file in the specified path
+    file_handler = logging.FileHandler("{0}/{1}".format(log_path, log_file))
+    file_handler.setFormatter(logFormatter)
+    logger.addHandler(file_handler)
+
+    # Console handler for logging on the console
+    consoleHandler = logging.StreamHandler()
+    consoleHandler.setFormatter(logFormatter)
+    logger.addHandler(consoleHandler)
+
+# Testing the log
+# setup_logging(log_path='.', log_file='mlc-log2.txt')
+# logger = logging.getLogger(__name__)
+# logger.info('This is an info message')

 # Base class for CLI actions
 class Action:
@@ -20,8 +76,10 @@ class Action:
     cfg = None
     action_type = None
     logger = None
+    local_repo = None
+    current_repo_path = None
     #mlc = None
-    repos = []
+    repos = []  # list of Repo objects

     def execute(self, args):
         raise NotImplementedError("Subclasses must implement the execute method")
@@ -40,12 +98,14 @@ def access(self, options):
             return {'return': 1, 'error': "'action' key is required in options"}

         #logger.info(f"options = {options}")
+        #print(f"options = {options}")
         action_target = options.get('target')
         if not action_target:
             action_target = options.get('automation', 'script')  # Default to script if not provided
         action_target_split = action_target.split(",")
         action_target = action_target_split[0]
+        #print(f"action_target = {action_target}")

         action = actions.get(action_target)
         #logger.info(f"action = {action}")
@@ -53,11 +113,10 @@ def access(self, options):
         if hasattr(action, action_name):
             # Find the method and call it with the options
             method = getattr(action, action_name)
+            #print(f"method = {method}, action = {action}, action_name = {action_name}")
             result = method(self, options)
             #logger.info(f"result ={result}")
             return result
-            #if result['return'] > 0:
-            #    return result
         else:
             return {'return': 1, 'error': f"'{action_name}' action is not supported for {action_target}."}
     else:
@@ -69,8 +128,6 @@ def find_target_folder(self, target):
         if not os.path.exists(self.repos_path):
             os.makedirs(self.repos_path, exist_ok=True)
         for repo_dir in os.listdir(self.repos_path):
-            #if "mlc" not in repo_dir:
-            #    continue
             repo_path = os.path.join(self.repos_path, repo_dir)
             if os.path.isdir(repo_path):
                 automation_folder = os.path.join(repo_path, 'automation')
@@ -84,8 +141,7 @@ def load_repos_and_meta(self):
         repos_list = []
-        repos_file_dir = os.path.dirname(self.repos_path)
-        repos_file_path = os.path.join(repos_file_dir, 'repos.json')
+        repos_file_path = os.path.join(self.repos_path, 'repos.json')

         # Read the JSON file line by line
         try:
@@ -93,37 +149,50 @@
             with open(repos_file_path, 'r') as file:
                 repo_paths = json.load(file)  # Load the JSON file into a list
         except json.JSONDecodeError as e:
-            logger.info(f"Error decoding JSON: {e}")
+            logger.error(f"Error decoding JSON: {e}")
             return []
         except FileNotFoundError:
-            logger.info(f"Error: File {repos_file_path} not found.")
+            logger.error(f"Error: File {repos_file_path} not found.")
             return []
         except Exception as e:
-            logger.info(f"Error reading file: {e}")
+            logger.error(f"Error reading file: {e}")
             return []
+
+        def is_curdir_inside_path(base_path):
+            # Convert to absolute paths
+            base_path = Path(base_path).resolve()
+            curdir = Path.cwd().resolve()
+
+            # Check if curdir is inside base_path (base_path must be an ancestor of curdir)
+            return base_path in curdir.parents or curdir == base_path
+
         # Iterate through the list of repository paths
         for repo_path in repo_paths:
+            if is_curdir_inside_path(repo_path):
+                self.current_repo_path = repo_path
             repo_path = repo_path.strip()  # Remove any extra whitespace or newlines

             # Skip empty lines
             if not repo_path:
                 continue

-            cmr_yaml_path = os.path.join(repo_path, "cmr.yaml")
+            meta_yaml_path = os.path.join(repo_path, "meta.yaml")

-            # Check if cmr.yaml exists
-            if not os.path.isfile(cmr_yaml_path):
-                logger.info(f"Warning: {cmr_yaml_path} not found. Skipping...")
+            # Check if meta.yaml exists
+            if not os.path.isfile(meta_yaml_path):
+                logger.warning(f"Warning: {meta_yaml_path} not found. Skipping...")
                 continue

             # Load the YAML file
             try:
-                with open(cmr_yaml_path, 'r') as yaml_file:
+                with open(meta_yaml_path, 'r') as yaml_file:
                     meta = yaml.safe_load(yaml_file)
             except yaml.YAMLError as e:
-                logger.info(f"Error loading YAML in {cmr_yaml_path}: {e}")
+                logger.error(f"Error loading YAML in {meta_yaml_path}: {e}")
                 continue

+            if meta.get('alias') == "local":
+                self.local_repo = (meta['alias'], meta['uid'])
+
             # Create a Repo object and add it to the list
             repos_list.append(Repo(path=repo_path, meta=meta))
@@ -131,13 +200,13 @@
         return repos_list

     def load_repos(self):
+        # todo: what if the repo is already found in the repos folder but not registered and we pull the same repo
         # Get the path to the repos.json file in $HOME/MLC
-        repos_file_dir = os.path.dirname(self.repos_path)
-        repos_file_path = os.path.join(repos_file_dir, 'repos.json')
+        repos_file_path = os.path.join(self.repos_path, 'repos.json')

         # Check if the file exists
         if not os.path.exists(repos_file_path):
-            logger.info(f"Error: File not found at {repos_file_path}")
+            logger.error(f"Error: File not found at {repos_file_path}")
             return None

         # Load and parse the JSON file
@@ -146,27 +215,98 @@
             repos = json.load(file)
             return repos
         except json.JSONDecodeError as e:
-            logger.info(f"Error decoding JSON: {e}")
+            logger.error(f"Error decoding JSON: {e}")
             return None
         except Exception as e:
-            logger.info(f"Error reading file: {e}")
+            logger.error(f"Error reading file: {e}")
             return None
+
+    def conflicting_repo(self, repo_meta):
+        for repo_object in self.repos:
+            if repo_object.meta.get('uid', '') == '':
+                return {"return": 1, "error": f"UID is not present in file 'meta.yaml' in the repo path {repo_object.path}"}
+            if repo_meta["uid"] == repo_object.meta.get('uid', ''):
+                #print(f"{repo_meta['path']} {repo_object.path}")
+                if repo_meta['path'] == repo_object.path:
+                    return {"return": 1, "error": f"Same repo is already registered"}
+                else:
+                    return {"return": 1, "error": f"Conflicting with repo in the path {repo_object.path}", "conflicting_path": repo_object.path}
+        return {"return": 0}
+
+    def register_repo(self, repo_meta):
+        # Get the path to the repos.json file in $HOME/MLC
+        repos_file_path = os.path.join(self.repos_path, 'repos.json')
+
+        with open(repos_file_path, 'r') as f:
+            repos_list = json.load(f)
+
+        new_repo_path = repo_meta.get('path')
+        if new_repo_path and new_repo_path not in repos_list:
+            repos_list.append(new_repo_path)
+            logger.info(f"Added new repo path: {new_repo_path}")
+
+        with open(repos_file_path, 'w') as f:
+            json.dump(repos_list, f, indent=2)
+            logger.info(f"Updated repos.json at {repos_file_path}")
+
+    def unregister_repo(self, repo_path):
+        logger.info(f"Unregistering the repo in path {repo_path}")
+        repos_file_path = os.path.join(self.repos_path, 'repos.json')
+
+        with open(repos_file_path, 'r') as f:
+            repos_list = json.load(f)
+
+        if repo_path in repos_list:
+            repos_list.remove(repo_path)
+            with open(repos_file_path, 'w') as f:
+                json.dump(repos_list, f, indent=2)
+            logger.info(f"Path: {repo_path} has been removed.")
+        else:
+            logger.info(f"Path: {repo_path} not found in {repos_file_path}. Nothing to be unregistered!")
+
     def __init__(self):
         self.logger = logging.getLogger()
         self.repos_path = os.environ.get('MLC_REPOS', os.path.expanduser('~/MLC/repos'))
+        '''
         res = self.access({'action': 'load',
                            'automation': 'cfg,88dce9c160324c5d',
                            'item': 'default'})
-        if res['return'] > 0:
-            return res
-        mlc_local_cache_path = os.path.join(self.repos_path, self.cfg['MLC_LOCAL_CACHE_FOLDER'])
-        if not os.path.exists(mlc_local_cache_path):
-            os.makedirs(mlc_local_cache_path, exist_ok=True)
+
+        #if res['return'] > 0:
+        #    return res
+        if self.cfg:
+            mlc_local_repo_path = os.path.join(self.repos_path, self.cfg.get('MLC_LOCAL_REPO_FOLDER', 'local'))
+        else:
+        '''
+        mlc_local_repo_path = os.path.join(self.repos_path, 'local')
+
+        mlc_local_repo_path_expanded = Path(mlc_local_repo_path).expanduser().resolve()
+
+        if not os.path.exists(mlc_local_repo_path):
+            os.makedirs(mlc_local_repo_path, exist_ok=True)
+
+        if not os.path.isfile(os.path.join(mlc_local_repo_path, "meta.yaml")):
+            local_repo_meta = {"alias": "local", "name": "MLC local repository", "uid": utils.get_new_uid()['uid']}
+            with open(os.path.join(mlc_local_repo_path, "meta.yaml"), "w") as yaml_file:
+                yaml.dump(local_repo_meta, yaml_file, default_flow_style=False)
+
+        # TODO: what if user changes the mlc local repo path in between
+        repo_json_path = os.path.join(self.repos_path, "repos.json")
+        if not os.path.exists(repo_json_path):
+            with open(repo_json_path, 'w') as f:
+                json.dump([str(mlc_local_repo_path_expanded)], f, indent=2)
+            logger.info(f"Created repos.json in {self.repos_path} and initialised it with the local repo path: {mlc_local_repo_path}")
+
+        self.local_cache_path = os.path.join(mlc_local_repo_path, "cache")
+        if not os.path.exists(self.local_cache_path):
+            os.makedirs(self.local_cache_path, exist_ok=True)
+
         self.repos = self.load_repos_and_meta()
         #logger.info(f"In Action class: {self.repos_path}")
         self.index = Index(self.repos_path, self.repos)
+
         #self.repos = {
         #'lst': repo_paths
         #}
@@ -189,10 +329,7 @@ def add(self, i):
         # Determine repository
         item_repo = i.get("item_repo")
         if not item_repo:
-            item_repo = (
-                self.cfg["local_repo_meta"]["alias"],
-                self.cfg["local_repo_meta"]["uid"],
-            )
+            item_repo = self.local_repo

         # Parse item details
         item = i.get("item")
@@ -224,25 +361,34 @@
             return res

         # Determine paths and metadata format
-        repo_path = res["path"]
-        repo_meta = {
-            'alias': item_repo[0],
-            'uid' : item_repo[1],
-        }
-        target_path = os.path.join(repo_path, self.action_type)
-        if self.action_type == "cache1":
-            folder_name = f"""{i["script_alias"]}_{item_name or item_id}""" if i.get("script_alias") else item_name or item_id
+        repo = res["list"][0]
+        repo_path = repo.path
+
+        target_name = i.get('target_name', self.action_type)
+        target_path = os.path.join(repo_path, target_name)
+        if target_name == "cache":
+            folder_name = f"""{i["script_alias"]}_{item_name or item_id[:8]}""" if i.get("script_alias") else item_name or item_id
         else:
             folder_name = item_name or item_id

         item_path = os.path.join(target_path, folder_name)
-        meta_format = "yaml" if i.get("yaml") else "json"
-        item_meta_path = os.path.join(item_path, f"_cm.{meta_format}")

         # Create item directory if it does not exist
         if not os.path.exists(item_path):
            os.makedirs(item_path)

+        res = self.save_new_meta(i, item_id, item_name, target_name, item_path, repo)
+        if res['return'] > 0:
+            return res
+
+        return {
+            "return": 0,
+            "message": f"Item successfully added at {item_path}",
+            "path": item_path,
+            "repo": repo
+        }
+
+    def save_new_meta(self, i, item_id, item_name, target_name, item_path, repo):
         # Prepare metadata
         item_meta = i.get('meta')
         item_meta.update({
@@ -256,6 +402,9 @@ def add(self, i):
         item_meta["tags"] = list(set(tags + new_tags))  # Ensure unique tags

         # Save metadata
+        meta_format = "yaml" if i.get("yaml") else "json"
+        item_meta_path = os.path.join(item_path, f"meta.{meta_format}")
+
         if meta_format == "yaml":
             save_result = utils.save_yaml(item_meta_path, meta=item_meta)
         else:
@@ -264,15 +413,8 @@ def add(self, i):
         if save_result["return"] > 0:
             return save_result

-        self.index.add(item_meta, self.action_type, item_path, repo_meta, repo_path)
-
-        return {
-            "return": 0,
-            "message": f"Item successfully added at {item_path}",
-            "path": item_path,
-            "meta": item_meta,
-            "repo": {"uid": repo_meta['uid'], "alias": repo_meta['alias'], "path": repo_path}
-        }
+        self.index.add(item_meta, target_name, item_path, repo)
+        return {'return': 0}

     def update(self, i):
         """
@@ -287,7 +429,10 @@
             dict: Return code and message.
         """
         # Step 1: Search for items based on input tags
+        target_name = i.get('target_name', "cache")
+        i['target_name'] = target_name
         ii = i.copy()
+
         if i.get('search_tags'):
             ii['tags'] = ",".join(i['search_tags'])
         search_result = self.search(ii)
@@ -320,7 +465,7 @@
         for item in found_items:
             meta = {}
             # Load the current meta of the item
-            item_meta_path = os.path.join(item.path, "_cm.json")
+            item_meta_path = os.path.join(item.path, "meta.json")
             if os.path.exists(item_meta_path):
                 res = utils.load_json(item_meta_path)
                 if res['return']> 0:
@@ -340,32 +485,149 @@
             #print(f"item.meta = {item.meta}, saved_meta = {saved_meta}")
             save_result = utils.save_json(item_meta_path, meta=meta)
             #print(f"item_meta = {item.meta}, path = {item.path}")
+            #print(item.repo_path)
+            #return {'return': 1}
+            self.index.update(meta, target_name, item.path, item.repo)

         return {'return': 0, 'message': f"Tags updated successfully for {len(found_items)} item(s).", 'list': found_items }

+    def is_uid(self, name):
+        """
+        Checks if the given name is a 16-digit hexadecimal UID.
+
+        Args:
+            name (str): The string to check.
+
+        Returns:
+            bool: True if the name is a 16-digit hexadecimal UID, False otherwise.
+ """ + # Define a regex pattern for a 16-digit hexadecimal UID + hex_uid_pattern = r"^[0-9a-fA-F]{16}$" + + # Check if the name matches the pattern + return bool(re.fullmatch(hex_uid_pattern, name)) + + def cp(self, run_args): + action_target = run_args['target'] + src_item = run_args['src'] + target_item = run_args['dest'] + src_split = src_item.split(":") + target_split = target_item.split(":") + if len(src_split) > 1: + src_repo = src_split[0] + src_item = src_split[1] + else: + src_item = src_split[0] + + inp = {} + inp['alias'] = src_item + inp['folder_name'] = src_item #we dont know if the user gave the alias or the folder name, we first check for alias and then the folder name + if self.is_uid(src_item): + inp['uid'] = src_item + + inp['target_name'] = action_target + res = self.search(inp) + + if len(res['list']) == 0: + return {'return': 1, 'error': f'No {action_target} found for {src_item}'} + elif len(res['list']) > 1: + return {'return': 1, 'error': f'More than 1 {action_target} found for {src_item}: {res["list"]}'} + else: + result = res['list'][0] + src_item_path = result.path + src_item_meta = result.meta + + + if len(target_split) > 1: + target_repo = target_split[0] + target_repo_path = os.path.join(self.repo_path, target_repo) + target_repo = Repo(target_repo_path) + target_item_name = target_split[1] + else: + target_repo = result.repo + target_repo_path = result.repo.path + target_item_name = target_split[0] + + target_item_path = os.path.join(target_repo_path, action_target, target_item) + #print(f"src_path = {src_item_path}, target_item = {target_item_name}, target_item_path = {target_item_path}, target_repo = {target_repo}") + res = self.copy_item(src_item_path, target_item_path) + if res['return'] > 0: + return res + + ii = {} + ii['meta'] = result.meta + if action_target == "script": + ii['yaml'] = True + + tags = run_args.get('tags') + item_id = run_args.get('item_id') + + if tags: + ii['tags'] = tags + + # Generate a new UID if not provided + if not item_id: + res = utils.get_new_uid() + if res['return'] > 0: + return res + item_id = res['uid'] + + res = self.save_new_meta(ii, item_id, target_item_name, action_target, target_item_path, target_repo) + + if res['return'] > 0: + return res + logging.info(f"{action_target} {src_item_path} copied to {target_item_path}") + + return {'return': 0} + + def copy_item(self, source_path, destination_path): + try: + # Copy the source folder to the destination + shutil.copytree(source_path, destination_path) + logging.info(f"Folder successfully copied from {source_path} to {destination_path}") + except FileExistsError: + return {'return': 1, 'error': f"Destination folder {destination_path} already exists."} + except FileNotFoundError: + return {'return': 1, 'error': f"Source folder {source_path} not found"} + except Exception as e: + return {'return': 1, 'error': f"An error occurred {e}"} + return {'return': 0} def search(self, i): indices = self.index.indices - target_index = indices.get(self.action_type) + #print(f"search input = {i}") + target = i.get('target_name', self.action_type) + target_index = indices.get(target) result = [] uid = i.get("uid") + alias = i.get("alias") + folder_name = i.get("folder_name") + found = False if target_index: - if uid: + if uid or alias: for res in target_index: - if res["uid"] == uid: + if res["uid"] == uid or (alias and res["alias"] == alias): it = Item(res['path'], res['repo']) result.append(it) + found = True + if not found and folder_name: + for res in target_index: + if 
os.path.basename(res["path"]) == folder_name: + it = Item(res['path'], res['repo']) + #result.append(it) else: tags= i.get("tags") tags_split = tags.split(",") - n_tags = [p for p in tags_split if p.startswith("-")] - p_tags = list(set(tags_split) - set(n_tags)) + n_tags_ = [p for p in tags_split if p.startswith("-")] + n_tags = [p[1:] for p in n_tags_] + p_tags = list(set(tags_split) - set(n_tags_)) for res in target_index: c_tags = res["tags"] if set(p_tags).issubset(set(c_tags)) and set(n_tags).isdisjoint(set(c_tags)): it = Item(res['path'], res['repo']) result.append(it) + #print(f"Search result for target {target} = {result}") return {'return': 0, 'list': result} @@ -390,7 +652,7 @@ def __init__(self, repos_path, repos): self.indices = {key: [] for key in self.index_files.keys()} self.build_index() - def add(self, meta, folder_type, path, repo_meta, repo_path): + def add(self, meta, folder_type, path, repo): unique_id = meta['uid'] alias = meta['alias'] tags = meta['tags'] @@ -399,9 +661,32 @@ def add(self, meta, folder_type, path, repo_meta, repo_path): "tags": tags, "alias": alias, "path": path, - "repo": {"uid": repo_meta['uid'], "alias": repo_meta['alias'], "path": repo_path} + "repo": repo }) + def get_index(self, folder_type, uid): + for index in range(len(self.indices[folder_type])): + if self.indices[folder_type][index]["uid"] == uid: + return index + return -1 + + def update(self, meta, folder_type, path, repo): + uid = meta['uid'] + alias = meta['alias'] + tags = meta['tags'] + index = self.get_index(folder_type, uid) + if index == -1: #add it + self.add(meta, folder_type, path, repo) + logger.debug(f"Index update failed, new index created for {uid}") + else: + self.indices[folder_type][index] = { + "uid": uid, + "tags": tags, + "alias": alias, + "path": path, + "repo": repo + } + def build_index(self): """ Build shared indices for script, cache, and experiment folders across all repositories. @@ -429,17 +714,17 @@ def build_index(self): if not os.path.isdir(automation_path): continue - # Check for configuration files (_cm.yaml or _cm.json) - for config_file in ["_cm.yaml", "_cm.json"]: + # Check for configuration files (meta.yaml or meta.json) + for config_file in ["meta.yaml", "meta.json"]: config_path = os.path.join(automation_path, config_file) if os.path.isfile(config_path): - self._process_config_file(config_path, folder_type, automation_path, repo.path, repo.meta) + self._process_config_file(config_path, folder_type, automation_path, repo) break # Only process one config file per automation_dir self._save_indices() - def _process_config_file(self, config_file, folder_type, folder_path, repo_path, repo_meta): + def _process_config_file(self, config_file, folder_type, folder_path, repo): """ - Process a single configuration file (_cm.json or _cm.yaml) and add its data to the corresponding index. + Process a single configuration file (meta.json or meta.yaml) and add its data to the corresponding index. Args: config_file (str): Path to the configuration file. 
@@ -473,47 +758,13 @@ def _process_config_file(self, config_file, folder_type, folder_path, repo_path,
                     "tags": tags,
                     "alias": alias,
                     "path": folder_path,
-                    "repo": {"uid": repo_meta['uid'], "alias": repo_meta['alias'], "path": repo_path}
+                    "repo": repo
                 })
             else:
                 logger.info(f"Skipping {config_file}: Missing 'uid' field.")

         except Exception as e:
-            logger.info(f"Error processing {config_file}: {e}")
-
-    '''
-    def _process_yaml_file(self, yaml_file, folder_type, folder_path):
-        """
-        Process a single _cm.yaml file and add its data to the corresponding index.
-
-        Args:
-            yaml_file (str): Path to the YAML file.
-            folder_type (str): Type of folder (script, cache, or experiment).
-            folder_path (str): Path to the folder containing the YAML file.
-
-        Returns:
-            None
-        """
-        try:
-            #logger.info(f"yaml file = {yaml_file}")
-            with open(yaml_file, "r") as f:
-                data = yaml.safe_load(f)
+            logger.error(f"Error processing {config_file}: {e}")

-            unique_id = data.get("uid")
-            tags = data.get("tags", [])
-            alias = data.get("alias", None)
-
-            if unique_id:
-                self.indices[folder_type].append({
-                    "uid": unique_id,
-                    "tags": tags,
-                    "alias": alias,
-                    "path": folder_path
-                })
-            else:
-                logger.info(f"Skipping {yaml_file}: Missing 'id' field.")
-        except Exception as e:
-            logger.info(f"Error processing {yaml_file}: {e}")
-    '''

     def _save_indices(self):
         """
@@ -527,36 +778,33 @@
             output_file = self.index_files[folder_type]
             try:
                 with open(output_file, "w") as f:
-                    json.dump(index_data, f, indent=4)
+                    json.dump(index_data, f, indent=4, cls=CustomJSONEncoder)
                 logger.info(f"Shared index for {folder_type} saved to {output_file}.")
             except Exception as e:
-                logger.info(f"Error saving shared index for {folder_type}: {e}")
-
+                logger.error(f"Error saving shared index for {folder_type}: {e}")
+
+class CustomJSONEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, Repo):
+            # Customize how to serialize the Repo object
+            return {
+                "path": obj.path,
+                "meta": obj.meta,
+            }
+        # For other unknown types, use the default behavior
+        return super().default(obj)

 class Item:
     def __init__(self, path, repo):
         self.meta = None
         self.path = path
-        self.repo_meta = None
-        self.repo_path = repo['path']
-        self._load_repo_meta()
+        self.repo = repo
         self._load_meta()

-    def _load_repo_meta(self):
-        yaml_file = os.path.join(self.repo_path, "cmr.yaml")
-
-        json_file = os.path.join(self.repo_path, "cmr.json")
-
-        if os.path.exists(yaml_file):
-            self.repo_meta = utils.read_yaml(yaml_file)
-        elif os.path.exists(json_file):
-            self.repo_meta = utils.read_json(json_file)
-        else:
-            logger.info(f"No meta file found in {self.repo_path}")

     def _load_meta(self):
-        yaml_file = os.path.join(self.path, "_cm.yaml")
-        json_file = os.path.join(self.path, "_cm.json")
+        yaml_file = os.path.join(self.path, "meta.yaml")
+        json_file = os.path.join(self.path, "meta.json")

         if os.path.exists(yaml_file):
             self.meta = utils.read_yaml(yaml_file)
@@ -569,7 +817,23 @@ class Repo:
     def __init__(self, path, meta):
         self.path = path
-        self.meta = meta
+        if meta:
+            self.meta = meta
+        else:
+            self._load_meta()
+
+
+    def _load_meta(self):
+        yaml_file = os.path.join(self.path, "meta.yaml")
+
+        json_file = os.path.join(self.path, "meta.json")
+
+        if os.path.exists(yaml_file):
+            self.meta = utils.read_yaml(yaml_file)
+        elif os.path.exists(json_file):
+            self.meta = utils.read_json(json_file)
+        else:
+            logger.info(f"No meta file found in {self.path}")

 class Automation:
     action_object = None
@@ -585,8 +849,8 @@ def __init__(self, action, automation_type, automation_file):
         self._load_meta()

     def _load_meta(self):
-        yaml_file = os.path.join(self.path, "_cm.yaml")
-        json_file = os.path.join(self.path, "_cm.json")
+        yaml_file = os.path.join(self.path, "meta.yaml")
+        json_file = os.path.join(self.path, "meta.json")

         if os.path.exists(yaml_file):
             self.meta = utils.read_yaml(yaml_file)
@@ -623,14 +887,22 @@ def search(self, i):

 # Extends Action class
 class RepoAction(Action):

-    def find(self, args):
-        repo = args.get('item')
-        repo_uid = repo.split(",")[1]
-        #print(f"args = {args}")
+    def find(self, run_args):
+        repo = run_args.get('item', run_args.get('artifact'))
+        repo_split = repo.split(",")
+        repo_uid = None
+        repo_name = repo_split[0]
+        if len(repo_split) > 1:
+            repo_uid = repo_split[1]
+
+        lst = []
         for i in self.repos:
-            if i.meta['uid'] == repo_uid:
-                return {'return': 0, 'path': i.path}
-        return {'return': 1, 'error': f'No repo found for uid {repo_uid}'}
+            if repo_uid and i.meta['uid'] == repo_uid:
+                lst.append(i)
+            elif repo_name == i.meta['alias']:
+                lst.append(i)
+
+        return {'return': 0, 'list': lst}

     def github_url_to_user_repo_format(self, url):
         """
@@ -647,21 +918,11 @@ def github_url_to_user_repo_format(self, url):
         match = re.match(pattern, url)

         if match:
             user, repo_name = match.groups()
-            return f"{user}@{repo_name}"
+            return {"return": 0, "value": f"{user}@{repo_name}"}
         else:
-            raise ValueError("Invalid GitHub URL format")
-
-    def pull(self, args):
-        repo_url = args.details if args.details else args.target_or_url
-        branch = None
-        checkout = None
-        extras = args.extra
-        for item in extras:
-            split = item.split("=")
-            if split[0] == "--branch":
-                branch = split[1]
-            elif split[0] == "--checkout":
-                checkout = split[1]
+            return {"return": 1, "error": f"Invalid GitHub URL format: {url}"}
+
+    def pull_repo(self, repo_url, branch=None, checkout=None):

         # Determine the checkout path from environment or default
         repo_base_path = self.repos_path # either the value will be from 'MLC_REPOS'
@@ -674,8 +935,12 @@
         # Extract the repo name from URL
         repo_name = repo_url.split('/')[-1].replace('.git', '')

-        repo_download_name = self.github_url_to_user_repo_format(repo_url)
-        repo_path = os.path.join(repo_base_path, repo_download_name)
+        res = self.github_url_to_user_repo_format(repo_url)
+        if res["return"] > 0:
+            return res
+        else:
+            repo_download_name = res["value"]
+            repo_path = os.path.join(repo_base_path, repo_download_name)

         try:
             # If the directory doesn't exist, clone it
@@ -688,6 +953,7 @@
                     clone_command = ['git', 'clone', '--branch', branch, repo_url, repo_path]

                 subprocess.run(clone_command, check=True)
+
             else:
                 logger.info(f"Repository {repo_name} already exists at {repo_path}. Pulling latest changes...")
                 subprocess.run(['git', '-C', repo_path, 'pull'], check=True)
@@ -698,15 +964,104 @@
                 subprocess.run(['git', '-C', repo_path, 'checkout', checkout], check=True)

             logger.info("Repository successfully pulled.")
+            logger.info("Registering the repo in repos.json")
+
+            # check the meta file to obtain uids
+            meta_file_path = os.path.join(repo_path, 'meta.yaml')
+            if not os.path.exists(meta_file_path):
+                logger.warning(f"meta.yaml not found in {repo_path}. Repo pulled but not registered in mlc repos. Skipping...")
+                return {"return": 0}
+
+            with open(meta_file_path, 'r') as meta_file:
+                meta_data = yaml.safe_load(meta_file)
+                meta_data["path"] = repo_path
+
+            # Check UID conflicts
+            is_conflict = self.conflicting_repo(meta_data)
+            if is_conflict['return'] > 0:
+                if "UID is not present" in is_conflict['error']:
+                    logger.warning(f"UID not found in meta.yaml at {repo_path}. Repo pulled but cannot be registered in mlc repos. Skipping...")
+                    return {"return": 0}
+                elif "already registered" in is_conflict["error"]:
+                    #logger.warning(is_conflict["error"])
+                    logger.info("No changes made to repos.json.")
+                    return {"return": 0}
+                else:
+                    logger.warning(f"The repo to be cloned conflicts with the repo already in the path: {is_conflict['conflicting_path']}")
+                    logger.warning(f"The repo currently being pulled will be registered in repos.json and the already existing one will be unregistered.")
+                    self.unregister_repo(is_conflict['conflicting_path'])
+                    self.register_repo(meta_data)
+                    return {"return": 0}
+            else:
+                self.register_repo(meta_data)
+                return {"return": 0}
+
         except subprocess.CalledProcessError as e:
-            logger.info(f"Git command failed: {e}")
+            return {'return': 1, 'error': f"Git command failed: {e}"}
         except Exception as e:
-            logger.info(f"Error pulling repository: {str(e)}")
+            return {'return': 1, 'error': f"Error pulling repository: {str(e)}"}
+
+    def pull(self, run_args):
+        repo_url = run_args.get('repo', 'repo')
+        if repo_url == "repo":
+            for repo_object in self.repos:
+                repo_folder_name = os.path.basename(repo_object.path)
+                if "@" in repo_folder_name:
+                    # GitHub clones are stored as owner@repo_name folders (see github_url_to_user_repo_format);
+                    # rebuild the URL since pull_repo expects one
+                    user, name = repo_folder_name.split("@", 1)
+                    res = self.pull_repo(f"https://github.com/{user}/{name}")
+                    if res['return'] > 0:
+                        return res
+        else:
+            branch = run_args.get('branch')
+            checkout = run_args.get('checkout')
+
+            res = self.pull_repo(repo_url, branch, checkout)
+            if res['return'] > 0:
+                return res
+
+        return {'return': 0}
+
     def list(self, run_args):
         logger.info("Listing all repositories.")
+        print("\nRepositories:")
+        print("-------------")
+        for repo_object in self.repos:
+            print(f"- Alias: {repo_object.meta.get('alias', 'Unknown')}")
+            print(f"  Path:  {repo_object.path}\n")
+        print("-------------")
+        logger.info("Repository listing ended")
+        return {"return": 0}
+
+    def rm(self, run_args):
+        logger.info("rm command has been called for repo. This will delete the repo folder and unregister the repo from repos.json")
+
+        if not run_args.get('repo'):
+            logger.error("The repository to be removed is not specified")
+            return {"return": 1, "error": "The repository to be removed is not specified"}
+
+        repo_folder_name = run_args['repo']
+
+        repo_path = os.path.join(self.repos_path, repo_folder_name)
+
+        if os.path.exists(repo_path):
+            shutil.rmtree(repo_path)
+            logger.info(f"Repo {run_args['repo']} residing in path {repo_path} has been successfully removed")
+            logger.info("Checking whether the repo was registered in repos.json")
+        else:
+            logger.warning(f"Repo {run_args['repo']} was not found in the repo folder. repos.json will be checked for any corrupted entry. If any, that will be removed.")

+        self.unregister_repo(repo_path)
+
+        return {"return": 0}
+

 class ScriptAction(Action):
+    def search(self, i):
+        if not i.get('target_name'):
+            i['target_name'] = "script"
+        return super().search(i)

     def dynamic_import_module(self, script_path):
         # Validate the script_path
@@ -731,57 +1084,46 @@
         return module

-    def update_script_run_args(self, run_args, inp):
-        for key in inp:
-            if "=" in key:
-                split = key.split("=")
-                run_args[split[0].strip("-")] = split[1]
-            elif key.startswith("-"):
-                run_args[key.strip("-")] = True
-
-    def run(self, args):
+    def call_script_module_function(self, function_name, run_args):
         self.action_type = "script"
-        #logger.info(f"Running script with identifier: {args.details}")
-
         # The REPOS folder is set by the user, for example via an environment variable.
         repos_folder = self.repos_path

-        logger.info(f"In script action {repos_folder}")
         # Import script submodule
         script_path = self.find_target_folder("script")
         module_path = os.path.join(script_path, "module.py")

         module = self.dynamic_import_module(module_path)

-        if "tags" in args: #  # called through access function
-            tags = args["tags"]
-            cmd = args
-            run_args = args
-        else:
-            tags = ""
-            for option in args.extra:
-                opt = option.split("=")
-                if opt[0] == "--tags":
-                    tags = opt[1]
-            cmd = args.extra
-
-            run_args = {'action': 'run', 'automation': 'script', 'tags': tags, 'cmd': cmd, 'out': 'con', 'parsed_automation': [('script', '5b4e0237da074764')]}
-            # update the run args with the extras that are supplied
-            self.update_script_run_args(run_args, args.extra)

         # Check if ScriptAutomation is defined in the module
         if hasattr(module, 'ScriptAutomation'):
             automation_instance = module.ScriptAutomation(self, module_path)
-            logger.info(f" script automation initialized at {module_path}")
-            #logger.info(run_args)
-            result = automation_instance.run(run_args)  # Pass args to the run method
-            #logger.info(result)
+            if function_name == "run":
+                result = automation_instance.run(run_args)  # Pass args to the run method
+            elif function_name == "docker":
+                result = automation_instance.docker(run_args)  # Pass args to the docker method
+            elif function_name == "test":
+                result = automation_instance.test(run_args)  # Pass args to the test method
+            else:
+                return {'return': 1, 'error': f'Function {function_name} is not supported'}
+
             if result['return'] > 0:
                 error = result.get('error', "")
-                raise ScriptExecutionError(f"Script execution failed. Error : {error}")
-            #logger.info(f"Script result: {result}")
+                raise ScriptExecutionError(f"Script {function_name} execution failed. Error: {error}")
             return result
         else:
-            logger.info("ScriptAutomation class not found in the script.")
+            logger.error("ScriptAutomation class not found in the script.")
+            return {'return': 1, 'error': "ScriptAutomation class not found in the script."}
+
+    def docker(self, run_args):
+        return self.call_script_module_function("docker", run_args)
+
+    def run(self, run_args):
+        return self.call_script_module_function("run", run_args)
+
+    def test(self, run_args):
+        return self.call_script_module_function("test", run_args)
+

     def list(self, args):
         logger.info("Listing all scripts.")

@@ -791,32 +1132,23 @@ class ScriptExecutionError(Exception):
     pass

 class CacheAction(Action):
+
+    def search(self, i):
+        i['target_name'] = "cache"
+        logger.debug(f"Searching for cache with input: {i}")
+        return super().search(i)
+
     def show(self, args):
         self.action_type = "cache"
         logger.info(f"Showing cache with identifier: {args.details}")

-    def find(self, args):
-        self.action_type = "cache"
+    def find(self, run_args):
         #logger.info(f"Running script with identifier: {args.details}")

         # The REPOS folder is set by the user, for example via an environment variable.
         #logger.info(f"In cache action {repos_folder}")
-
-        if "tags" in args: #  access function
-            tags = args["tags"]
-            cmd = args
-            run_args = args
-        else:
-            tags = ""
-            for option in args.extra:
-                opt = option.split("=")
-                if opt[0] == "--tags":
-                    tags = opt[1]
-            cmd = args.extra
-
-            run_args = {'action': 'run', 'automation': 'script', 'tags': tags, 'cmd': cmd, 'out': 'con', 'parsed_automation': [('cache', '541d6f712a6b464e')]}
-        #self.update_script_run_args(run_args, args.extra)
-
+        run_args['target_name'] = "cache"
+        #print(f"run_args = {run_args}")
         return self.search(run_args)

     def list(self, args):
@@ -839,10 +1171,11 @@ def load(self, args):
             args (dict): Contains the configuration details such as file path, etc.
         """
         #logger.info("In cfg load")
-        config_file = args.get('config_file', 'config.yaml')
+        default_config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.yaml')
+        config_file = args.get('config_file', default_config_path)
         logger.info(f"In cfg load, config file = {config_file}")
         if not config_file or not os.path.exists(config_file):
-            logger.info(f"Error: Configuration file '{config_file}' not found.")
+            logger.error(f"Error: Configuration file '{config_file}' not found.")
            return {'return': 1, 'error': f"Error: Configuration file '{config_file}' not found."}

         #logger.info(f"Loading configuration from {config_file}")
@@ -855,7 +1188,7 @@
             # Store configuration in memory or perform other operations
             self.cfg = config_data
         except yaml.YAMLError as e:
-            logger.info(f"Error loading YAML configuration: {e}")
+            logger.error(f"Error loading YAML configuration: {e}")

         return {'return': 0, 'config': self.cfg}
@@ -870,9 +1203,7 @@
             logger.info(f"Unloading configuration.")
             del self.config  # Remove the loaded config from memory
         else:
-            logger.info("Error: No configuration is currently loaded.")
-
-
+            logger.error("Error: No configuration is currently loaded.")

 actions = {
     'repo': RepoAction,
@@ -887,6 +1218,25 @@ def get_action(target):
     action_class = actions.get(target, None)
     return action_class() if action_class else None

+
+def access(i):
+    action = i['action']
+    target = i.get('target', i.get('automation', 'script'))
+    action_class = get_action(target)
+    r = action_class.access(i)
+    return r
+
+def mlcr():
+    first_arg_value = "run"
+    second_arg_value = "script"
+
+    # Insert the positional arguments into sys.argv for the main function
+    sys.argv.insert(1, first_arg_value)
+    sys.argv.insert(2, second_arg_value)
+
+    # Call the main function
+    main()
+
 # Main CLI function
 def main():
     parser = argparse.ArgumentParser(prog='mlc', description='A CLI tool for managing repos, scripts, and caches.')
@@ -895,75 +1245,65 @@
     # The chosen subcommand will be stored in the "command" attribute of the parsed arguments.
     subparsers = parser.add_subparsers(dest='command', required=True)

-    # Pull parser - handles repo URLs directly
-    # The chosen subcommand will be stored in the "pull" attribute of the parsed arguments.
-    pull_parser = subparsers.add_parser('pull', help='Pull a repository by URL or target.')
-    pull_parser.add_argument('target_or_url', help='Target (repo) or URL for the repository.')
-
-    pull_parser.add_argument('details', nargs='?', help='Optional details or identifier.')
-    pull_parser.add_argument('extra', nargs=argparse.REMAINDER, help='Extra options (e.g., -v)')
+    for action in ['pull']:
+        # Pull parser - handles repo URLs directly
+        # The chosen subcommand will be stored in the "pull" attribute of the parsed arguments.
+        pull_parser = subparsers.add_parser(action, help='Pull a repository by URL or target.')
+        pull_parser.add_argument('target', choices=['repo'], help='Target type (repo).')
+        pull_parser.add_argument('repo', nargs='?', help='Repo to pull in URL format or owner@repo_name format for GitHub repos')
+        pull_parser.add_argument('extra', nargs=argparse.REMAINDER, help='Extra options (e.g., -v)')

     # Script and Cache-specific subcommands
-    for action in ['run', 'show', 'update', 'list', 'find']:
-        action_parser = subparsers.add_parser(action, help=f'{action.capitalize()} a target.')
+    for action in ['run', 'test', 'show', 'update', 'list', 'find', 'search', 'rm', 'cp', 'mv']:
+        action_parser = subparsers.add_parser(action, help=f'{action} a target.')
         action_parser.add_argument('target', choices=['repo', 'script', 'cache'], help='Target type (repo, script, cache).')
         # the argument given after target and before any extra options like --tags will be stored in "details"
         action_parser.add_argument('details', nargs='?', help='Details or identifier (optional for list).')
         action_parser.add_argument('extra', nargs=argparse.REMAINDER, help='Extra options (e.g., -v)')

+    # Script-specific subcommands
+    for action in ['docker', 'help']:
+        action_parser = subparsers.add_parser(action, help=f'{action.capitalize()} a target.')
+        action_parser.add_argument('target', choices=['script'], help='Target type (script).')
+        # the argument given after target and before any extra options like --tags will be stored in "details"
+        action_parser.add_argument('details', nargs='?', help='Details or identifier (optional for list).')
+        action_parser.add_argument('extra', nargs=argparse.REMAINDER, help='Extra options (e.g., -v)')
+
     for action in ['load']:
         load_parser = subparsers.add_parser(action, help=f'{action.capitalize()} a target.')
-        load_parser.add_argument('target', choices=['utils', 'cfg'], help='Target type (utils, cfg).')
+        load_parser.add_argument('target', choices=['cfg'], help='Target type (cfg).')

-    for action in [ 'get_host_os_info']:
-        utils_parser = subparsers.add_parser(action, help=f'{action.capitalize()} a target.')
-        utils_parser.add_argument('target', choices=['utils'], help='Target type (utils).')

     # Parse arguments
     args = parser.parse_args()

-    logger.info(f"Args = {args}")
-
-
-    # Parse extra options into a dictionary
-    options = {}
-    for opt in args.extra:
-        if opt.startswith('--'):
-            # Handle --key=value (long form)
-            if '=' in opt:
-                key, value = opt.lstrip('--').split('=')
-                options[key] = value
-            else:
-                options[opt.lstrip('--')] = True  # --key (flag without value)
-        elif opt.startswith('-'):
-            # Handle short options (-j or -xyz)
-            for char in opt.lstrip('-'):
-                options[char] = True
-        else:
-            logger.info(f"Warning: Unrecognized option '{opt}' ignored.")
-
-    if args.command == 'pull':
-        # If the first argument looks like a URL, assume repo pull
-        if args.target_or_url.startswith("http"):
-            action = RepoAction()
-            action.pull(args)
-        else:
-            action = get_action(args.target_or_url)
-            if action and hasattr(action, 'pull'):
-                action.pull(args)
-            else:
-                logger.info(f"Error: '{args.target_or_url}' is not a valid target for pull.")
+    #logger.info(f"Args = {args}")
+
+    res = utils.convert_args_to_dictionary(getattr(args, 'extra', []))
+    if res['return'] > 0:
+        return res
+
+    run_args = res['args_dict']
+    if hasattr(args, 'repo') and args.repo:
+        run_args['repo'] = args.repo
+
+    if args.command in ["cp", "mv"]:
+        run_args['target'] = args.target
+        if hasattr(args, 'details') and args.details:
+            run_args['src'] = args.details
+        if hasattr(args, 'extra') and args.extra:
+            run_args['dest'] = args.extra[0]
+
+    # Get the action handler for other commands
+    action = get_action(args.target)
+    # Dynamically call the method (e.g., run, list, show)
+    if action and hasattr(action, args.command):
+        method = getattr(action, args.command)
+        res = method(run_args)
+        if res['return'] > 0:
+            logger.error(res.get('error', f"Error in {action}"))
     else:
-        # Get the action handler for other commands
-        logger.info(f"Going for action = {args.target}")
-        action = get_action(args.target)
-        logger.info(f"Got action = {action}")
-        # Dynamically call the method (e.g., run, list, show)
-        if action and hasattr(action, args.command):
-            method = getattr(action, args.command)
-            method(args)
-        else:
-            logger.info(f"Error: '{args.command}' is not supported for {args.target}.")
+        logger.error(f"Error: '{args.command}' is not supported for {args.target}.")

 if __name__ == '__main__':
     main()
diff --git a/mlc/utils.py b/mlc/utils.py
index c72667352..4a701d2c9 100644
--- a/mlc/utils.py
+++ b/mlc/utils.py
@@ -6,10 +6,11 @@
 import platform
 import json
 import yaml
-
-import os
 import uuid
 import shutil
+import tarfile
+import zipfile
+from packaging import version  # needed by compare_versions below

 def generate_temp_file(i):
     """
@@ -88,7 +89,8 @@ def load_txt(file_name, check_if_exists=False, split=False, match_text=None, fai
     result = {'return': 0}

     if split:
-        result['string'] = content.splitlines()
+        result['list'] = content.splitlines()
+        result['string'] = content
     else:
         result['string'] = content

@@ -100,40 +102,83 @@
     except Exception as e:
         return {'return': 1, 'error': str(e)}

-'''
-def load_txt(file_name, remove_after_read=False, check_if_exists=True):
+def compare_versions(current_version, min_version):
     """
-    Loads the content of a text file into a string, with the option to delete the file after reading.
+    Compare two semantic version strings.

     Args:
-        file_name (str): The path to the text file to read.
-        remove_after_read (bool): If True, the file will be removed after reading.
+        current_version (str): The current version string (e.g., "1.2.3").
+        min_version (str): The minimum required version string (e.g., "1.0.0").

     Returns:
-        dict: A dictionary containing:
-            return (int): Return code, 0 if no error, >0 if error
-            error (str): Error string if return > 0
-            string (str): The content of the file, or an empty string if there is an error.
+        int: -1 if current_version < min_version,
+              0 if current_version == min_version,
+              1 if current_version > min_version.
     """
     try:
-        # Check if the file exists
-        if not os.path.isfile(file_name):
-            return {'return': 1, 'error': f"File {file_name} not found", 'string': ''}
+        # Use `packaging.version` to handle semantic version comparison
+        current = version.parse(current_version)
+        minimum = version.parse(min_version)
+
+        if current < minimum:
+            return -1
+        elif current > minimum:
+            return 1
+        else:
+            return 0
+    except Exception as e:
+        raise ValueError(f"Invalid version format: {e}")

+def run_system_cmd(i):
+    """
+    Execute a system command in a specified path.

-        # Read the content of the file
-        with open(file_name, 'r') as file:
-            file_content = file.read()
+    Args:
+        i (dict): A dictionary containing:
+            - 'path' (str): The directory to run the command in.
+            - 'cmd' (str): The system command to execute.
+
+    Returns:
+        dict: A dictionary with the result of the execution:
+            - {'return': 0, 'output': <command output>} on success.
+            - {'return': 1, 'error': <error message>} on failure.
+    """
+    # Extract path and cmd from the input dictionary
+    path = i.get('path', '.')
+    cmd = i.get('cmd', '')
+
+    if not cmd:
+        return {'return': 1, 'error': 'No command provided to execute.'}

-        # Optionally remove the file after reading
-        if remove_after_read:
-            os.remove(file_name)
+    if not os.path.exists(path):
+        return {'return': 1, 'error': f"Specified path does not exist: '{path}'"}

+    # Change to the specified path and execute the command
+    try:
+        result = subprocess.run(
+            cmd,
+            cwd=path,
+            shell=True,
+            check=True,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            text=True
+        )
+        return {
+            'return': 0,
+            'output': result.stdout.strip(),
+            'error_output': result.stderr.strip()
+        }
+    except subprocess.CalledProcessError as e:
+        return {
+            'return': 1,
+            'error': f"Command execution failed with error code {e.returncode}.",
+            'error_output': e.stderr.strip()
+        }
     except Exception as e:
-        return {'return': 1, 'error': str(e), 'string': ''}
-'''
+        return {'return': 1, 'error': f"Unexpected error occurred: {str(e)}"}
+
+
 def print_env(env):
     print_formatted_json(env)
@@ -209,8 +253,22 @@ def merge_dicts(params, in_place=True):
         elif isinstance(existing_value, list) and isinstance(value, list):
             if append_lists:
                 if append_unique:
-                    # Append only unique values from the second list
-                    merged_dict[key] = list(set(existing_value + value))
+                    # Combine dictionaries uniquely based on their key-value pairs
+                    seen = set()
+                    merged_list = []
+                    for item in existing_value + value:
+                        if isinstance(item, dict):
+                            try:
+                                item_frozenset = frozenset(item.items())
+                            except TypeError:
+                                item_frozenset = id(item)
+                        else:
+                            item_frozenset = item
+                        if item_frozenset not in seen:
+                            seen.add(item_frozenset)
+                            merged_list.append(item)
+                    merged_dict[key] = merged_list
+
                 else:
                     # Simply append the values
                     merged_dict[key] = existing_value + value
@@ -248,6 +306,27 @@ def save_json(file_name, meta):
     except Exception as e:
         return {'return': 1, 'error': str(e)}

+
+def save_yaml(file_name, meta):
+    """
+    Saves the provided meta data to a YAML file.
+
+    Args:
+        file_name (str): The name of the file where the YAML data will be saved.
+        meta (dict): The dictionary containing the data to be saved in YAML format.
+
+    Returns:
+        dict: A dictionary indicating success or failure of the operation.
+            - 'return' (int): 0 if the operation was successful, > 0 if an error occurred.
+            - 'error' (str): Error message, if any error occurred.
+    """
+    try:
+        with open(file_name, 'w') as f:
+            yaml.dump(meta, f, default_flow_style=False, sort_keys=False)
+        return {'return': 0, 'error': ''}
+    except Exception as e:
+        return {'return': 1, 'error': str(e)}
+
 def save_txt(file_name, string):
     """
     Saves the provided string to a text file.
@@ -268,6 +347,48 @@
     except Exception as e:
         return {'return': 1, 'error': str(e)}

+
+def convert_args_to_dictionary(inp):
+    args_dict = {}
+    for key in inp:
+        if "=" in key:
+            split = key.split("=", 1)  # Split only on the first "="
+            arg_key = split[0].strip("-")
+            arg_value = split[1]
+
+            # Handle lists: Only if "," is immediately before the "="
+            if "," in arg_key:
+                list_key, list_values = arg_key.rsplit(",", 1)
+                if not list_values:  # Ensure "=" follows the last comma
+                    args_dict[list_key] = arg_value.split(",")
+                    continue
+
+            # Handle dictionaries: `--adr.compiler.tags=gcc` becomes `{"adr": {"compiler": {"tags": "gcc"}}}`
+            elif "." in arg_key:
+                keys = arg_key.split(".")
+                current = args_dict
+                for part in keys[:-1]:
+                    if part not in current or not isinstance(current[part], dict):
+                        current[part] = {}
+                    current = current[part]
+                current[keys[-1]] = arg_value
+
+            # Handle simple key-value pairs
+            else:
+                args_dict[arg_key] = arg_value
+
+        # Handle flags: `--flag` becomes `{"flag": True}`
+        elif key.startswith("--"):
+            args_dict[key.strip("-")] = True
+
+        # Handle short options (-j or -xyz)
+        elif key.startswith("-"):
+            for char in key.lstrip('-'):
+                args_dict[char] = True
+
+    return {'return': 0, 'args_dict': args_dict}
+
+
 def sub_input(i, keys, reverse=False):
     """
     Extracts and returns values from the dictionary based on the provided keys.
@@ -394,7 +515,7 @@ def convert_env_to_dict(env_text):

     return {'return': 0, 'dict': env_dict}

-def load_json(file_name):
+def load_json(file_name, encoding=None):
     """
     Load JSON data from a file and handle errors.
@@ -408,8 +529,12 @@
         - 'meta': The loaded JSON data if successful
     """
     try:
-        with open(file_name, 'r') as f:
-            meta = json.load(f)
+        if encoding:
+            with open(file_name, 'r', encoding=encoding) as f:
+                meta = json.load(f)
+        else:
+            with open(file_name, 'r') as f:
+                meta = json.load(f)

         return {'return': 0, 'meta': meta}
@@ -455,3 +580,58 @@
     tags_list = [tag.strip() for tag in tags_string.split(',') if tag.strip()]

     return {'return': 0, 'tags': tags_list}
+
+
+
+def extract_file(options):
+    """
+    Extracts a compressed file, optionally stripping folder levels.
+
+    Args:
+        options (dict): A dictionary with the following keys:
+            - 'filename' (str): The path to the compressed file to extract.
+            - 'strip_folders' (int, optional): The number of folder levels to strip. Default is 0.
+
+    Raises:
+        ValueError: If the file format is unsupported.
+        FileNotFoundError: If the specified file does not exist.
+ """ + filename = options.get('filename') + strip_folders = options.get('strip_folders', 0) + + if not filename or not os.path.exists(filename): + raise FileNotFoundError(f"File not found: {filename}") + + extract_to = os.path.join(os.path.dirname(filename), "extracted") + os.makedirs(extract_to, exist_ok=True) + + # Check file type and extract accordingly + if zipfile.is_zipfile(filename): + with zipfile.ZipFile(filename, 'r') as archive: + members = archive.namelist() + for member in members: + # Strip folder levels + stripped_path = os.path.join( + extract_to, *member.split(os.sep)[strip_folders:] + ) + if member.endswith('/'): # Directory + os.makedirs(stripped_path, exist_ok=True) + else: # File + os.makedirs(os.path.dirname(stripped_path), exist_ok=True) + with archive.open(member) as source, open(stripped_path, 'wb') as target: + shutil.copyfileobj(source, target) + + elif tarfile.is_tarfile(filename): + with tarfile.open(filename, 'r') as archive: + members = archive.getmembers() + for member in members: + if strip_folders: + parts = member.name.split('/') + member.name = '/'.join(parts[strip_folders:]) + archive.extract(member, path=extract_to) + + else: + raise ValueError(f"Unsupported file format: {filename}") + + print(f"Extraction complete. Files extracted to: {extract_to}") + diff --git a/pyproject.toml b/pyproject.toml index b47594009..80d53c3c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,12 +4,12 @@ build-backend = "setuptools.build_meta" [project] name = "mlcflow" -version = "0.1.0" +version = "0.1.8" description = "An automation interface for ML applications" authors = [ { name = "MLCommons", email = "systems@mlcommons.org" } ] -license = { file = "LICENSE" } +license = { file = "LICENSE.md" } readme = "README.md" requires-python = ">=3.7" keywords = ["mlc", "mlcflow", "pypi", "package", "automation"] @@ -20,7 +20,10 @@ classifiers = [ ] dependencies = [ - "requests" + "requests", + "pyyaml", + "giturlparse", + "colorama" ] [project.urls] @@ -31,7 +34,12 @@ Issues = "https://github.com/mlcommons/mlcflow/issues" [tool.setuptools] packages = ["mlc"] +#include-package-data = true + +#[tool.setuptools.package-data] +# This includes the config.yaml file +#"mlc" = ["config.yaml"] [project.scripts] mlc = "mlc.main:main" -cm = "mlc.main:main" +mlcr = "mlc.main:mlcr"