Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions .github/workflows/tests-cpu.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@

name: Test CPU

on:
workflow_call:
inputs:
artifact-name:
description: 'Run Forge unit tests on CPUs.'
required: true
type: string

concurrency:
group: test-cpu-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }}
cancel-in-progress: true

jobs:
test-cpu-no-tensor-engine:
name: Test CPU
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
timeout: 60
runner: linux.4xlarge
submodules: recursive
download-artifact: ${{ inputs.artifact-name }}
script: |
# Source common setup functions
source scripts/common-setup.sh

# Setup test environment
setup_conda_environment

pip install .[dev]
pip install .[oss]

pytest tests/unit_tests -vs
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ dependencies = [
# PyTorch
"torchdata>=0.8.0",
"torchtitan",
"torchao",
# vLLM
# TODO: pin specific vllm version
#"vllm==0.10.0",
Expand All @@ -36,6 +37,7 @@ dev = [
"pytest-cov",
"tensorboard",
"tomli>=1.1.0",
"tomli_w",
"anyio",
"pytest-asyncio",
]
Expand Down
5 changes: 5 additions & 0 deletions scripts/common-setup.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
25 changes: 23 additions & 2 deletions src/forge/actors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,27 @@
# LICENSE file in the root directory of this source tree.

from .collector import Collector
from .policy import Policy, PolicyRouter

__all__ = ["Collector", "Policy", "PolicyRouter"]
__all__ = ["Collector"]

try:
from .policy import Policy, PolicyRouter

__all__.extend(["Policy", "PolicyRouter"])
except ImportError as e:
# Create placeholder classes that give helpful error messages
class Policy:
def __init__(self, *args, **kwargs):
raise ImportError(
"Policy requires vLLM to be installed. "
"Install it with: pip install vllm"
) from e

class PolicyRouter:
def __init__(self, *args, **kwargs):
raise ImportError(
"PolicyRouter requires vLLM to be installed. "
"Install it with: pip install vllm"
) from e

__all__.extend(["Policy", "PolicyRouter"])
7 changes: 4 additions & 3 deletions src/forge/data/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .dataset import DatasetInfo, InfiniteTuneIterableDataset
from .dataset import DatasetInfo, InfiniteTuneIterableDataset, InterleavedDataset
from .hf_dataset import HfIterableDataset
from .packed import PackedDataset
from .sft_dataset import SFTOutputTransform, sft_iterable_dataset
from .sft_dataset import sft_iterable_dataset, SFTOutputTransform

__all__ = [
"DatasetInfo",
"HfIterableDataset",
"InterleavedDataset",
"InfiniteTuneIterableDataset",
"PackedDataset",
"SFTOutputTransform",
"sft_iterable_dataset",
]
]
Loading