diff --git a/.github/ISSUE_TEMPLATE/v3_feature.md b/.github/ISSUE_TEMPLATE/v3_feature.md new file mode 100644 index 0000000..9c207ec --- /dev/null +++ b/.github/ISSUE_TEMPLATE/v3_feature.md @@ -0,0 +1,66 @@ +--- +name: v3.0.0 Feature Implementation +about: Track implementation of a v3.0.0 feature +title: '[v3.0.0] ' +labels: 'enhancement, v3.0.0' +assignees: '' + +--- + +## Feature Name + + +## Phase + + +## Branch Name + + +## Description + + +## Implementation Checklist + +### Code Implementation +- [ ] Core functionality implemented +- [ ] Data models created +- [ ] Error handling added +- [ ] Type hints complete +- [ ] Docstrings written + +### Testing +- [ ] Unit tests written +- [ ] Integration tests written +- [ ] Mock tests for offline testing +- [ ] Test coverage >90% +- [ ] All tests passing + +### Documentation +- [ ] API documentation updated +- [ ] Usage examples added +- [ ] README updated if needed +- [ ] CHANGELOG entry added + +### Code Quality +- [ ] Code formatted (make format) +- [ ] Linting passing (make lint) +- [ ] Type checking passing (make type-check) +- [ ] Pre-commit hooks passing + +## Acceptance Criteria + +- [ ] +- [ ] +- [ ] + +## API Endpoints Covered + +- `GET /v2/...` +- `POST /v2/...` + +## Related Issues/PRs + +- + +## Notes + diff --git a/.github/ISSUE_TEMPLATE/v3_progress.md b/.github/ISSUE_TEMPLATE/v3_progress.md new file mode 100644 index 0000000..fcd031d --- /dev/null +++ b/.github/ISSUE_TEMPLATE/v3_progress.md @@ -0,0 +1,68 @@ +--- +name: v3.0.0 Progress Update +about: Weekly progress update for v3.0.0 development +title: '[v3.0.0 Progress] Week of ' +labels: 'v3.0.0, progress' +assignees: '' + +--- + +## Week of [DATE] + +## Completed This Week + +### Features Completed +- [ ] Feature name (PR #) +- [ ] Feature name (PR #) + +### Tests Added +- Total new tests: +- Current coverage: + +### Documentation Updates +- [ ] Updated feature docs +- [ ] Added examples + +## In Progress + +### Currently Working On +- Feature: 
[name] - [% complete] +- Feature: [name] - [% complete] + +### Blockers + +- + +## Next Week's Plan + +### Features to Start +- [ ] Feature name +- [ ] Feature name + +### Features to Complete +- [ ] Feature name +- [ ] Feature name + +## Overall Progress + +### Phase Status +- **Phase 1**: [0]% Complete (Critical Features) +- **Phase 2**: [0]% Complete (Important Enhancements) +- **Phase 3**: [0]% Complete (Performance & Quality) +- **Phase 4**: [0]% Complete (Advanced Features) + +### Metrics +- Total API Coverage: [X]% +- Test Coverage: [X]% +- Documentation Complete: [X]% + +## Risk Assessment + +### New Risks Identified +- + +### Mitigation Actions +- + +## Notes + diff --git a/.github/workflows/claude-code-review.yml b/.github/workflows/claude-code-review.yml new file mode 100644 index 0000000..4caf96a --- /dev/null +++ b/.github/workflows/claude-code-review.yml @@ -0,0 +1,54 @@ +name: Claude Code Review + +on: + pull_request: + types: [opened, synchronize] + # Optional: Only run on specific file changes + # paths: + # - "src/**/*.ts" + # - "src/**/*.tsx" + # - "src/**/*.js" + # - "src/**/*.jsx" + +jobs: + claude-review: + # Optional: Filter by PR author + # if: | + # github.event.pull_request.user.login == 'external-contributor' || + # github.event.pull_request.user.login == 'new-developer' || + # github.event.pull_request.author_association == 'FIRST_TIME_CONTRIBUTOR' + + runs-on: ubuntu-latest + permissions: + contents: read + pull-requests: read + issues: read + id-token: write + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 1 + + - name: Run Claude Code Review + id: claude-review + uses: anthropics/claude-code-action@v1 + with: + claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }} + prompt: | + Please review this pull request and provide feedback on: + - Code quality and best practices + - Potential bugs or issues + - Performance considerations + - Security concerns + - Test coverage + + Use the 
repository's CLAUDE.md for guidance on style and conventions. Be constructive and helpful in your feedback. + + Use `gh pr comment` with your Bash tool to leave your review as a comment on the PR. + + # See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md + # or https://docs.anthropic.com/en/docs/claude-code/sdk#command-line for available options + claude_args: '--allowed-tools "Bash(gh issue view:*),Bash(gh search:*),Bash(gh issue list:*),Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*)"' + diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml new file mode 100644 index 0000000..ae36c00 --- /dev/null +++ b/.github/workflows/claude.yml @@ -0,0 +1,50 @@ +name: Claude Code + +on: + issue_comment: + types: [created] + pull_request_review_comment: + types: [created] + issues: + types: [opened, assigned] + pull_request_review: + types: [submitted] + +jobs: + claude: + if: | + (github.event_name == 'issue_comment' && contains(github.event.comment.body, '@claude')) || + (github.event_name == 'pull_request_review_comment' && contains(github.event.comment.body, '@claude')) || + (github.event_name == 'pull_request_review' && contains(github.event.review.body, '@claude')) || + (github.event_name == 'issues' && (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude'))) + runs-on: ubuntu-latest + permissions: + contents: read + pull-requests: read + issues: read + id-token: write + actions: read # Required for Claude to read CI results on PRs + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 1 + + - name: Run Claude Code + id: claude + uses: anthropics/claude-code-action@v1 + with: + claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }} + + # This is an optional setting that allows Claude to read CI results on PRs + additional_permissions: | + actions: read + + # Optional: Give a custom prompt to Claude. 
If this is not specified, Claude will perform the instructions specified in the comment that tagged it. + # prompt: 'Update the pull request description to include a summary of changes.' + + # Optional: Add claude_args to customize behavior and configuration + # See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md + # or https://docs.anthropic.com/en/docs/claude-code/sdk#command-line for available options + # claude_args: '--model claude-opus-4-1-20250805 --allowed-tools Bash(gh pr:*)' + diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..4eda884 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,114 @@ +# Changelog + +All notable changes to py-alpaca-api will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [3.0.0] - Unreleased + +### Overview +Major release adding complete Alpaca Stock API coverage, performance improvements, and real-time data support. 
+ +### Added +- πŸ“‹ Comprehensive development plan (DEVELOPMENT_PLAN.md) +- πŸ—οΈ New v3.0.0 branch structure for organized development + +### Planned Features (In Development) +#### Phase 1: Critical Missing Features +- [ ] Corporate Actions API - Track dividends, splits, mergers +- [ ] Trade Data Support - Access to individual trade data +- [ ] Market Snapshots - Current market overview for symbols + +#### Phase 2: Important Enhancements +- [ ] Account Configuration Management +- [ ] Enhanced Order Management (replace, extended hours) +- [ ] Market Metadata (condition codes, exchange codes) + +#### Phase 3: Performance & Quality +- [ ] Batch Operations for multiple symbols +- [ ] Feed Management System (IEX/SIP/OTC) +- [ ] Caching System with configurable TTL + +#### Phase 4: Advanced Features +- [ ] WebSocket Streaming Support +- [ ] Async/Await Implementation + +### Changed +- Restructured project for v3.0.0 development + +### Deprecated +- None + +### Removed +- None + +### Fixed +- None + +### Security +- None + +## [2.2.0] - 2024-12-15 + +### Added +- Stock analysis tools with ML predictions +- Market screener for gainers/losers +- News aggregation from multiple sources +- Sentiment analysis for stocks +- Prophet integration for price forecasting + +### Changed +- Improved error handling across all modules +- Enhanced DataFrame operations +- Better type safety with mypy strict mode + +### Fixed +- Yahoo Finance news fetching reliability +- DataFrame type preservation issues +- Prophet seasonality parameter handling + +## [2.1.0] - 2024-11-01 + +### Added +- Watchlist management functionality +- Portfolio history tracking +- Market calendar support +- Extended order types (bracket, trailing stop) + +### Changed +- Improved pagination for large datasets +- Better rate limit handling + +### Fixed +- Order validation for fractional shares +- Timezone handling in market hours + +## [2.0.0] - 2024-09-15 + +### Added +- Complete rewrite with modular architecture 
+- Full type hints and mypy support +- Comprehensive test suite (109+ tests) +- Separate trading and stock modules + +### Changed +- Breaking: New API structure with PyAlpacaAPI class +- Breaking: All methods now return typed dataclasses +- Improved error handling with custom exceptions + +### Removed +- Legacy API methods +- Deprecated authentication methods + +## [1.0.0] - 2024-06-01 + +### Added +- Initial release +- Basic trading operations +- Market data retrieval +- Account management + +--- + +*For detailed migration guides between versions, see [MIGRATION.md](MIGRATION.md)* diff --git a/DEVELOPMENT_PLAN.md b/DEVELOPMENT_PLAN.md new file mode 100644 index 0000000..738a47e --- /dev/null +++ b/DEVELOPMENT_PLAN.md @@ -0,0 +1,246 @@ +# py-alpaca-api Development Plan + +## πŸ“‹ Overview + +This document outlines the future development plan for py-alpaca-api, focusing on advanced features and continuous improvements. + +**Current Version**: 3.0.0 (Released) +**Next Version**: 3.1.0 (WebSocket Streaming) +**Future Version**: 3.2.0 (Async Support) + +## 🎯 Completed in v3.0.0 + +### βœ… Phase 1: Critical Missing Features +- Corporate Actions API +- Trade Data Support +- Market Snapshots + +### βœ… Phase 2: Important Enhancements +- Account Configuration +- Market Metadata +- Enhanced Order Management + +### βœ… Phase 3: Performance & Quality +- Batch Operations for multi-symbol data +- Feed Management System with automatic fallback +- Caching System with LRU and Redis support + +## πŸš€ Future Development + +### Version 3.1.0: WebSocket Streaming +**Target Release**: Q2 2025 +**Branch**: `feature/websocket-streaming` + +#### Goals +- Real-time market data streaming +- Reduced latency for live trading +- Efficient connection management +- Comprehensive error handling + +#### Tasks +- [ ] Create `streaming/` module structure +- [ ] Implement `StreamClient` class +- [ ] Add real-time quote streaming +- [ ] Add real-time trade streaming +- [ ] Add real-time bar 
aggregation +- [ ] Implement reconnection logic +- [ ] Add subscription management +- [ ] Add comprehensive tests (15+ test cases) +- [ ] Update documentation with examples + +#### Acceptance Criteria +- Stable WebSocket connection with automatic reconnection +- Efficient message parsing and handling +- Support for multiple symbol subscriptions +- Clean shutdown mechanism +- Comprehensive error handling and recovery + +### Version 3.2.0: Async Support +**Target Release**: Q3 2025 +**Branch**: `feature/async-support` + +#### Goals +- Full async/await support for all API methods +- Improved performance for concurrent operations +- Better resource utilization +- Backwards compatibility maintained + +#### Tasks +- [ ] Create `AsyncPyAlpacaAPI` class +- [ ] Implement async versions of all methods +- [ ] Add connection pooling with aiohttp +- [ ] Implement async rate limiting +- [ ] Add async cache support +- [ ] Create async streaming client +- [ ] Add comprehensive tests (20+ test cases) +- [ ] Update documentation with async examples + +#### Acceptance Criteria +- All methods have async equivalents +- Proper connection pooling and reuse +- Efficient concurrent execution +- Backwards compatible (sync API still works) +- Performance improvements documented + +## 🌳 Branching Strategy + +``` +main + └── v3.1.0 (for WebSocket features) + └── feature/websocket-streaming + └── v3.2.0 (for Async support) + └── feature/async-support +``` + +### Workflow +1. Create version branch from `main` +2. Create feature branches from version branch +3. Implement features with tests +4. Create PR to merge into version branch +5. Code review and testing +6. 
When complete, PR from version branch to `main` + +## πŸ“Š Roadmap + +| Version | Features | Status | Target Date | +|---------|----------|---------|-------------| +| 3.0.0 | Core API Coverage, Performance, Caching | βœ… Released | January 2025 | +| 3.1.0 | WebSocket Streaming | ⬜ Planned | Q2 2025 | +| 3.2.0 | Async Support | ⬜ Planned | Q3 2025 | +| 3.3.0 | Advanced Analytics | ⬜ Future | Q4 2025 | +| 4.0.0 | Options Trading Support | ⬜ Future | 2026 | + +## πŸ§ͺ Testing Strategy + +### Requirements +- Minimum 90% code coverage for new features +- All public methods must have tests +- Integration tests for API endpoints +- Mock tests for development without API keys +- Performance benchmarks for async operations + +### Test Categories +1. **Unit Tests**: Individual function testing +2. **Integration Tests**: API endpoint testing +3. **Performance Tests**: Load and efficiency testing +4. **Mock Tests**: Testing without live API +5. **Regression Tests**: Ensure backward compatibility + +## πŸ“ Documentation Requirements + +### For Each Feature +1. **API Documentation**: Comprehensive docstrings +2. **Usage Examples**: Practical code examples +3. **Migration Guide**: For any breaking changes +4. **Performance Guide**: For optimization tips +5. 
**Troubleshooting**: Common issues and solutions + +## πŸš€ Release Process + +### Version Strategy +- **x.x.0-alpha.x**: Early development releases +- **x.x.0-beta.x**: Feature complete, testing phase +- **x.x.0-rc.x**: Release candidates +- **x.x.0**: Stable release + +### Release Checklist +- [ ] All tests passing +- [ ] Documentation complete +- [ ] CHANGELOG updated +- [ ] Migration guide written (if needed) +- [ ] Performance benchmarks documented +- [ ] Security audit completed +- [ ] Package version bumped +- [ ] GitHub release created +- [ ] PyPI package published + +## πŸ” Code Review Standards + +For each PR: +- [ ] Code follows project style guide +- [ ] All tests passing +- [ ] Test coverage β‰₯ 90% +- [ ] Documentation updated +- [ ] Type hints complete +- [ ] No breaking changes (or properly documented) +- [ ] Performance impact assessed +- [ ] Security implications reviewed + +## πŸ“Š Success Metrics + +### Technical Metrics +- API coverage: 100% of stock endpoints +- Test coverage: >90% +- Performance: <50ms average response time (async) +- WebSocket stability: >99.9% uptime +- Memory usage: <100MB for typical operations + +### User Metrics +- GitHub stars growth +- PyPI downloads increase +- Issue resolution time <48 hours +- Community engagement metrics + +## 🀝 Contributing + +### How to Contribute +1. Check the roadmap for planned features +2. Open an issue to discuss your contribution +3. Fork the repository +4. Create a feature branch +5. Implement with tests and documentation +6. 
Submit a PR with all checks passing + +### Contribution Guidelines +- Follow the existing code style +- Include comprehensive tests +- Update documentation +- Add examples where appropriate +- Ensure backward compatibility + +## 🚨 Known Challenges + +### WebSocket Implementation +- **Challenge**: Maintaining stable connections +- **Solution**: Implement robust reconnection logic with exponential backoff + +### Async Migration +- **Challenge**: Maintaining backward compatibility +- **Solution**: Separate async classes while keeping sync API intact + +### Performance at Scale +- **Challenge**: Handling thousands of concurrent connections +- **Solution**: Connection pooling and efficient resource management + +## πŸ“… Maintenance Schedule + +### Regular Tasks +- **Weekly**: Review and triage new issues +- **Monthly**: Update dependencies +- **Quarterly**: Performance audit +- **Yearly**: Major version planning + +## πŸ“Œ Resources + +- [Alpaca API Documentation](https://docs.alpaca.markets/reference) +- [Project Repository](https://github.com/TexasCoding/py-alpaca-api) +- [Issue Tracker](https://github.com/TexasCoding/py-alpaca-api/issues) +- [PyPI Package](https://pypi.org/project/py-alpaca-api/) +- [WebSocket API Docs](https://docs.alpaca.markets/docs/real-time-market-data) +- [Python Async Best Practices](https://docs.python.org/3/library/asyncio.html) + +## 🎯 Definition of Done + +A feature is considered complete when: +1. βœ… All code implemented and reviewed +2. βœ… All tests passing (>90% coverage) +3. βœ… Documentation complete +4. βœ… Performance benchmarks met +5. βœ… No critical bugs reported in testing +6. 
βœ… Migration guide provided (if needed) + +--- + +**Last Updated**: 2025-01-16 +**Document Version**: 2.0.0 +**Maintained By**: py-alpaca-api Development Team diff --git a/README.md b/README.md index bef0991..41fea8a 100644 --- a/README.md +++ b/README.md @@ -11,14 +11,26 @@ A modern Python wrapper for the Alpaca Trading API, providing easy access to tra ## ✨ Features +### Core Features - **πŸ” Complete Alpaca API Coverage**: Trading, market data, account management, and more - **πŸ“Š Stock Market Analysis**: Built-in screeners for gainers/losers, historical data analysis +- **πŸš€ Batch Operations**: Efficient multi-symbol data fetching with automatic batching (200+ symbols) - **πŸ€– ML-Powered Predictions**: Stock price predictions using Facebook Prophet - **πŸ“° Financial News Integration**: Real-time news from Yahoo Finance and Benzinga - **πŸ“ˆ Technical Analysis**: Stock recommendations and sentiment analysis - **🎯 Type Safety**: Full type annotations with mypy strict mode -- **πŸ§ͺ Battle-Tested**: 100+ tests with comprehensive coverage -- **⚑ Modern Python**: Async-ready, Python 3.10+ with latest best practices +- **πŸ§ͺ Battle-Tested**: 300+ tests with comprehensive coverage +- **⚑ Modern Python**: Python 3.10+ with latest best practices + +### New in v3.0.0 +- **πŸ“Έ Market Snapshots**: Get complete market snapshots with latest trade, quote, and bar data +- **βš™οΈ Account Configuration**: Manage PDT settings, trade confirmations, and margin configurations +- **πŸ“‹ Market Metadata**: Access condition codes, exchange information, and trading metadata +- **πŸ”„ Enhanced Orders**: Replace orders, client order IDs, and advanced order management +- **🎯 Smart Feed Management**: Automatic feed selection and fallback (SIP β†’ IEX β†’ OTC) +- **πŸ’Ύ Intelligent Caching**: Built-in caching system with configurable TTLs for optimal performance +- **🏒 Corporate Actions**: Track dividends, splits, mergers, and other corporate events +- **πŸ“Š Trade Data API**: 
Access historical and real-time trade data with pagination ## πŸ“¦ Installation @@ -99,16 +111,30 @@ api.trading.orders.cancel_all() ### Market Data & Analysis ```python -# Get historical stock data +# Get historical stock data for a single symbol history = api.stock.history.get( symbol="TSLA", start="2024-01-01", end="2024-12-31" ) -# Get real-time quote +# NEW: Get historical data for multiple symbols (batch operation) +symbols = ["AAPL", "GOOGL", "MSFT", "TSLA", "AMZN"] +multi_history = api.stock.history.get( + symbol=symbols, # Pass a list for batch operation + start="2024-01-01", + end="2024-12-31" +) +# Returns DataFrame with all symbols' data, automatically handles batching for 200+ symbols + +# Get real-time quote for a single symbol quote = api.stock.latest_quote.get("MSFT") -print(f"MSFT Price: ${quote.ask_price}") +print(f"MSFT Price: ${quote.ask}") + +# NEW: Get real-time quotes for multiple symbols (batch operation) +quotes = api.stock.latest_quote.get(["AAPL", "GOOGL", "MSFT"]) +for quote in quotes: + print(f"{quote.symbol}: ${quote.ask}") # Screen for top gainers gainers = api.stock.screener.gainers( @@ -201,6 +227,252 @@ api.trading.watchlists.add_assets_to_watchlist( watchlists = api.trading.watchlists.get_all_watchlists() ``` +### Corporate Actions + +```python +# Get dividend announcements +dividends = api.trading.corporate_actions.get_announcements( + since="2024-01-01", + until="2024-03-31", + ca_types=["dividend"], + symbol="AAPL" # Optional: filter by symbol +) + +for dividend in dividends: + print(f"{dividend.initiating_symbol}: ${dividend.cash_amount} on {dividend.payable_date}") + +# Get stock splits +splits = api.trading.corporate_actions.get_announcements( + since="2024-01-01", + until="2024-03-31", + ca_types=["split"] +) + +for split in splits: + print(f"{split.initiating_symbol}: {split.split_from}:{split.split_to} split") + +# Get mergers and acquisitions +mergers = api.trading.corporate_actions.get_announcements( + 
since="2024-01-01", + until="2024-03-31", + ca_types=["merger"] +) + +# Get specific announcement by ID +announcement = api.trading.corporate_actions.get_announcement_by_id("123456") +print(f"Corporate Action: {announcement.ca_type} for {announcement.initiating_symbol}") + +# Get all types of corporate actions +all_actions = api.trading.corporate_actions.get_announcements( + since="2024-01-01", + until="2024-03-31", + ca_types=["dividend", "split", "merger", "spinoff"], + date_type="ex_dividend" # Filter by specific date type +) +``` + +### Trade Data + +```python +# Get historical trades for a symbol +trades_response = api.stock.trades.get_trades( + symbol="AAPL", + start="2024-01-15T09:30:00Z", + end="2024-01-15T10:00:00Z", + limit=100 +) + +for trade in trades_response.trades: + print(f"Trade: {trade.size} shares @ ${trade.price} on {trade.exchange}") + +# Get latest trade for a symbol +latest_trade = api.stock.trades.get_latest_trade("MSFT") +print(f"Latest MSFT trade: ${latest_trade.price} x {latest_trade.size}") + +# Get trades for multiple symbols +multi_trades = api.stock.trades.get_trades_multi( + symbols=["AAPL", "MSFT", "GOOGL"], + start="2024-01-15T09:30:00Z", + end="2024-01-15T10:00:00Z", + limit=10 +) + +for symbol, trades_data in multi_trades.items(): + print(f"{symbol}: {len(trades_data.trades)} trades") + +# Get all trades with automatic pagination +all_trades = api.stock.trades.get_all_trades( + symbol="SPY", + start="2024-01-15T09:30:00Z", + end="2024-01-15T09:35:00Z" +) +print(f"Total SPY trades: {len(all_trades)}") + +# Use different data feeds (requires subscription) +sip_trades = api.stock.trades.get_trades( + symbol="AAPL", + start="2024-01-15T09:30:00Z", + end="2024-01-15T10:00:00Z", + feed="sip" # or "iex", "otc" +) +``` + +### Market Snapshots + +```python +# Get snapshot for a single symbol +snapshot = api.stock.snapshots.get_snapshot("AAPL") +print(f"Latest trade: ${snapshot.latest_trade.price}") +print(f"Latest quote: Bid 
${snapshot.latest_quote.bid} / Ask ${snapshot.latest_quote.ask}") +print(f"Daily bar: Open ${snapshot.daily_bar.open} / Close ${snapshot.daily_bar.close}") +print(f"Previous daily: Open ${snapshot.prev_daily_bar.open} / Close ${snapshot.prev_daily_bar.close}") + +# Get snapshots for multiple symbols (efficient batch operation) +symbols = ["AAPL", "GOOGL", "MSFT", "TSLA", "NVDA"] +snapshots = api.stock.snapshots.get_snapshots(symbols) +for symbol, snapshot in snapshots.items(): + print(f"{symbol}: ${snapshot.latest_trade.price} ({snapshot.daily_bar.volume:,} volume)") + +# Get snapshots with specific feed +snapshots = api.stock.snapshots.get_snapshots( + symbols=["SPY", "QQQ"], + feed="iex" # or "sip", "otc" +) +``` + +### Account Configuration + +```python +# Get current account configuration +config = api.trading.account.get_configuration() +print(f"PDT Check: {config.pdt_check}") +print(f"Trade Confirm Email: {config.trade_confirm_email}") +print(f"Suspend Trade: {config.suspend_trade}") +print(f"No Shorting: {config.no_shorting}") + +# Update account configuration +updated_config = api.trading.account.update_configuration( + trade_confirm_email=True, + suspend_trade=False, + pdt_check="both", # "both", "entry", or "exit" + no_shorting=False +) +print("Account configuration updated successfully") +``` + +### Market Metadata + +```python +# Get condition codes for trades +condition_codes = api.stock.metadata.get_condition_codes(tape="A") +for code in condition_codes: + print(f"Code {code.code}: {code.description}") + +# Get exchange codes +exchanges = api.stock.metadata.get_exchange_codes() +for exchange in exchanges: + print(f"{exchange.code}: {exchange.name} ({exchange.type})") + +# Get all condition codes at once (cached for performance) +all_codes = api.stock.metadata.get_all_condition_codes() +print(f"Loaded {len(all_codes)} condition codes") + +# Lookup specific codes +code_info = api.stock.metadata.lookup_condition_code("R") +print(f"Code R means: 
{code_info.description}") +``` + +### Enhanced Order Management + +```python +# Place order with client order ID for tracking +order = api.trading.orders.market( + symbol="AAPL", + qty=1, + side="buy", + client_order_id="my-app-order-123" +) + +# Replace an existing order (modify price, quantity, etc.) +replaced_order = api.trading.orders.replace_order( + order_id=order.id, + qty=2, # Change quantity + limit_price=155.00 # Add/change limit price +) + +# Get order by client order ID (useful for tracking) +orders = api.trading.orders.get_all(status="open") +my_order = next((o for o in orders if o.client_order_id == "my-app-order-123"), None) + +# Advanced OCO/OTO orders +oco_order = api.trading.orders.limit( + symbol="TSLA", + qty=1, + side="buy", + limit_price=200.00, + order_class="oco", # One-Cancels-Other + take_profit={"limit_price": 250.00}, + stop_loss={"stop_price": 180.00} +) +``` + +### Smart Feed Management + +```python +# The library automatically manages feed selection based on your subscription +# No configuration needed - it automatically detects and falls back as needed + +# Manual feed configuration (optional) +from py_alpaca_api.http.feed_manager import FeedManager, FeedConfig, FeedType + +# Configure preferred feeds +feed_config = FeedConfig( + preferred_feed=FeedType.SIP, # Try SIP first + fallback_feeds=[FeedType.IEX], # Fall back to IEX if needed + auto_fallback=True # Automatically handle permission errors +) + +# The feed manager automatically: +# - Detects your subscription level (Basic/Unlimited/Business) +# - Falls back to available feeds on permission errors +# - Caches failed feeds to avoid repeated attempts +# - Provides clear logging for debugging +``` + +### Intelligent Caching System + +```python +# Caching is built-in and automatic for improved performance +# Configure caching (optional - sensible defaults are provided) +from py_alpaca_api.cache import CacheManager, CacheConfig + +# Custom cache configuration +cache_config = 
CacheConfig( + max_size=1000, # Maximum items in cache + default_ttl=300, # Default time-to-live in seconds + data_ttls={ + "market_hours": 86400, # 1 day + "assets": 3600, # 1 hour + "quotes": 1, # 1 second + "positions": 10, # 10 seconds + } +) + +# Cache manager automatically: +# - Caches frequently accessed data +# - Reduces API calls and improves response times +# - Manages memory efficiently with LRU eviction +# - Supports optional Redis backend for distributed caching + +# Use the @cached decorator for custom caching +cache_manager = CacheManager(cache_config) + +@cache_manager.cached("custom_data", ttl=600) +def expensive_calculation(symbol: str): + # This result will be cached for 10 minutes + return complex_analysis(symbol) +``` + ### Advanced Order Types ```python @@ -300,27 +572,36 @@ make lint ``` py-alpaca-api/ β”œβ”€β”€ src/py_alpaca_api/ -β”‚ β”œβ”€β”€ __init__.py # Main API client -β”‚ β”œβ”€β”€ exceptions.py # Custom exceptions -β”‚ β”œβ”€β”€ trading/ # Trading operations -β”‚ β”‚ β”œβ”€β”€ account.py # Account management -β”‚ β”‚ β”œβ”€β”€ orders.py # Order management -β”‚ β”‚ β”œβ”€β”€ positions.py # Position tracking -β”‚ β”‚ β”œβ”€β”€ watchlists.py # Watchlist operations -β”‚ β”‚ β”œβ”€β”€ market.py # Market data -β”‚ β”‚ β”œβ”€β”€ news.py # Financial news -β”‚ β”‚ └── recommendations.py # Stock analysis -β”‚ β”œβ”€β”€ stock/ # Stock market data -β”‚ β”‚ β”œβ”€β”€ assets.py # Asset information -β”‚ β”‚ β”œβ”€β”€ history.py # Historical data -β”‚ β”‚ β”œβ”€β”€ screener.py # Stock screening -β”‚ β”‚ β”œβ”€β”€ predictor.py # ML predictions -β”‚ β”‚ └── latest_quote.py # Real-time quotes -β”‚ β”œβ”€β”€ models/ # Data models -β”‚ └── http/ # HTTP client -β”œβ”€β”€ tests/ # Test suite -β”œβ”€β”€ docs/ # Documentation -└── pyproject.toml # Project configuration +β”‚ β”œβ”€β”€ __init__.py # Main API client +β”‚ β”œβ”€β”€ exceptions.py # Custom exceptions +β”‚ β”œβ”€β”€ trading/ # Trading operations +β”‚ β”‚ β”œβ”€β”€ account.py # Account management & 
configuration +β”‚ β”‚ β”œβ”€β”€ orders.py # Order management (enhanced) +β”‚ β”‚ β”œβ”€β”€ positions.py # Position tracking +β”‚ β”‚ β”œβ”€β”€ watchlists.py # Watchlist operations +β”‚ β”‚ β”œβ”€β”€ market.py # Market hours & calendar +β”‚ β”‚ β”œβ”€β”€ news.py # Financial news +β”‚ β”‚ β”œβ”€β”€ recommendations.py # Stock analysis +β”‚ β”‚ └── corporate_actions.py # Corporate events (v3.0.0) +β”‚ β”œβ”€β”€ stock/ # Stock market data +β”‚ β”‚ β”œβ”€β”€ assets.py # Asset information +β”‚ β”‚ β”œβ”€β”€ history.py # Historical data (batch support) +β”‚ β”‚ β”œβ”€β”€ screener.py # Stock screening +β”‚ β”‚ β”œβ”€β”€ predictor.py # ML predictions +β”‚ β”‚ β”œβ”€β”€ latest_quote.py # Real-time quotes (batch support) +β”‚ β”‚ β”œβ”€β”€ trades.py # Trade data API (v3.0.0) +β”‚ β”‚ β”œβ”€β”€ snapshots.py # Market snapshots (v3.0.0) +β”‚ β”‚ └── metadata.py # Market metadata (v3.0.0) +β”‚ β”œβ”€β”€ models/ # Data models +β”‚ β”œβ”€β”€ cache/ # Caching system (v3.0.0) +β”‚ β”‚ β”œβ”€β”€ cache_manager.py # Cache management +β”‚ β”‚ └── cache_config.py # Cache configuration +β”‚ └── http/ # HTTP client +β”‚ β”œβ”€β”€ requests.py # Request handling +β”‚ └── feed_manager.py # Feed management (v3.0.0) +β”œβ”€β”€ tests/ # Test suite (300+ tests) +β”œβ”€β”€ docs/ # Documentation +└── pyproject.toml # Project configuration ``` ## πŸ“– Documentation @@ -379,13 +660,34 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file ## πŸ—ΊοΈ Roadmap +### v3.0.0 (Current Release) +- βœ… Complete Alpaca Stock API coverage +- βœ… Market Snapshots API +- βœ… Account Configuration API +- βœ… Market Metadata API +- βœ… Enhanced Order Management +- βœ… Corporate Actions API +- βœ… Trade Data API +- βœ… Smart Feed Management System +- βœ… Intelligent Caching System +- βœ… Batch Operations for all data endpoints + +### v3.1.0 (Planned) - [ ] WebSocket support for real-time data streaming +- [ ] Live market data subscriptions +- [ ] Real-time order and trade updates + +### 
v3.2.0 (Planned) +- [ ] Full async/await support +- [ ] Concurrent API operations +- [ ] Async context managers + +### Future Releases - [ ] Options trading support - [ ] Crypto trading integration - [ ] Advanced portfolio analytics - [ ] Backtesting framework - [ ] Strategy automation tools -- [ ] Mobile app integration ## ⚠️ Disclaimer diff --git a/pyproject.toml b/pyproject.toml index 7d00e85..b99d9f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "py-alpaca-api" -version = "2.2.0" +version = "3.0.0" description = "Add your description here" readme = "README.md" requires-python = ">=3.10" diff --git a/src/py_alpaca_api/cache/__init__.py b/src/py_alpaca_api/cache/__init__.py new file mode 100644 index 0000000..f828bcd --- /dev/null +++ b/src/py_alpaca_api/cache/__init__.py @@ -0,0 +1,9 @@ +"""Cache module for py-alpaca-api. + +This module provides caching functionality to improve performance and reduce API calls. +""" + +from .cache_config import CacheConfig, CacheType +from .cache_manager import CacheManager + +__all__ = ["CacheManager", "CacheConfig", "CacheType"] diff --git a/src/py_alpaca_api/cache/cache_config.py b/src/py_alpaca_api/cache/cache_config.py new file mode 100644 index 0000000..1db2dc7 --- /dev/null +++ b/src/py_alpaca_api/cache/cache_config.py @@ -0,0 +1,68 @@ +"""Cache configuration for py-alpaca-api.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum + + +class CacheType(Enum): + """Types of cache backends supported.""" + + MEMORY = "memory" + REDIS = "redis" + DISABLED = "disabled" + + +@dataclass +class CacheConfig: + """Configuration for cache system. 
+ + Attributes: + cache_type: Type of cache backend to use + max_size: Maximum number of items in memory cache + default_ttl: Default time-to-live in seconds + data_ttls: TTL overrides per data type + redis_host: Redis host (if using Redis) + redis_port: Redis port (if using Redis) + redis_db: Redis database number (if using Redis) + redis_password: Redis password (if using Redis) + enabled: Whether caching is enabled + """ + + cache_type: CacheType = CacheType.MEMORY + max_size: int = 1000 + default_ttl: int = 300 # 5 minutes default + data_ttls: dict[str, int] = field( + default_factory=lambda: { + "market_hours": 86400, # 1 day + "calendar": 86400, # 1 day + "assets": 3600, # 1 hour + "account": 60, # 1 minute + "positions": 10, # 10 seconds + "orders": 5, # 5 seconds + "quotes": 1, # 1 second + "bars": 60, # 1 minute + "trades": 60, # 1 minute + "news": 300, # 5 minutes + "watchlists": 300, # 5 minutes + "snapshots": 1, # 1 second + "metadata": 86400, # 1 day (condition codes, exchanges) + } + ) + redis_host: str = "localhost" + redis_port: int = 6379 + redis_db: int = 0 + redis_password: str | None = None + enabled: bool = True + + def get_ttl(self, data_type: str) -> int: + """Get TTL for a specific data type. 
+ + Args: + data_type: Type of data to get TTL for + + Returns: + TTL in seconds + """ + return self.data_ttls.get(data_type, self.default_ttl) diff --git a/src/py_alpaca_api/cache/cache_manager.py b/src/py_alpaca_api/cache/cache_manager.py new file mode 100644 index 0000000..fca1d46 --- /dev/null +++ b/src/py_alpaca_api/cache/cache_manager.py @@ -0,0 +1,462 @@ +"""Cache manager for py-alpaca-api.""" + +from __future__ import annotations + +import hashlib +import json +import logging +import time +from collections import OrderedDict +from collections.abc import Callable +from dataclasses import asdict, is_dataclass +from typing import Any + +from py_alpaca_api.cache.cache_config import CacheConfig, CacheType + +logger = logging.getLogger(__name__) + + +class LRUCache: + """Least Recently Used (LRU) cache implementation.""" + + def __init__(self, max_size: int = 1000): + """Initialize LRU cache. + + Args: + max_size: Maximum number of items to store + """ + self.max_size = max_size + self.cache: OrderedDict[str, tuple[Any, float]] = OrderedDict() + + def get(self, key: str) -> Any | None: + """Get item from cache. + + Args: + key: Cache key + + Returns: + Cached value or None if not found/expired + """ + if key not in self.cache: + return None + + value, expiry = self.cache[key] + + if time.time() > expiry: + del self.cache[key] + return None + + # Move to end to mark as recently used + self.cache.move_to_end(key) + return value + + def set(self, key: str, value: Any, ttl: int) -> None: + """Set item in cache. + + Args: + key: Cache key + value: Value to cache + ttl: Time-to-live in seconds + """ + expiry = time.time() + ttl + self.cache[key] = (value, expiry) + self.cache.move_to_end(key) + + # Enforce size limit + while len(self.cache) > self.max_size: + self.cache.popitem(last=False) + + def delete(self, key: str) -> bool: + """Delete item from cache. 
+ + Args: + key: Cache key + + Returns: + True if deleted, False if not found + """ + if key in self.cache: + del self.cache[key] + return True + return False + + def clear(self) -> None: + """Clear all items from cache.""" + self.cache.clear() + + def size(self) -> int: + """Get current cache size. + + Returns: + Number of items in cache + """ + return len(self.cache) + + def cleanup_expired(self) -> int: + """Remove expired items from cache. + + Returns: + Number of items removed + """ + current_time = time.time() + expired_keys = [ + key for key, (_, expiry) in self.cache.items() if current_time > expiry + ] + + for key in expired_keys: + del self.cache[key] + + return len(expired_keys) + + +class RedisCache: + """Redis cache implementation.""" + + def __init__(self, config: CacheConfig): + """Initialize Redis cache. + + Args: + config: Cache configuration + """ + self.config = config + self._client = None + + def _get_client(self): + """Get or create Redis client.""" + if self._client is None: + try: + import redis + + self._client = redis.Redis( + host=self.config.redis_host, + port=self.config.redis_port, + db=self.config.redis_db, + password=self.config.redis_password, + decode_responses=True, + ) + # Test connection + self._client.ping() + logger.info("Redis cache connected successfully") + except ImportError: + logger.warning("Redis not installed, falling back to memory cache") + raise + except Exception: + logger.exception("Failed to connect to Redis") + raise + + return self._client + + def get(self, key: str) -> Any | None: + """Get item from cache. + + Args: + key: Cache key + + Returns: + Cached value or None if not found + """ + try: + client = self._get_client() + value = client.get(key) + if value: + return json.loads(value) + return None + except Exception as e: + logger.warning(f"Redis get failed: {e}") + return None + + def set(self, key: str, value: Any, ttl: int) -> None: + """Set item in cache. 
+ + Args: + key: Cache key + value: Value to cache + ttl: Time-to-live in seconds + """ + try: + client = self._get_client() + json_value = json.dumps(value, default=str) + client.setex(key, ttl, json_value) + except Exception as e: + logger.warning(f"Redis set failed: {e}") + + def delete(self, key: str) -> bool: + """Delete item from cache. + + Args: + key: Cache key + + Returns: + True if deleted, False if not found + """ + try: + client = self._get_client() + return bool(client.delete(key)) + except Exception as e: + logger.warning(f"Redis delete failed: {e}") + return False + + def clear(self) -> None: + """Clear all items from cache.""" + try: + client = self._get_client() + client.flushdb() + except Exception as e: + logger.warning(f"Redis clear failed: {e}") + + def size(self) -> int: + """Get current cache size. + + Returns: + Number of items in cache + """ + try: + client = self._get_client() + return client.dbsize() + except Exception as e: + logger.warning(f"Redis size failed: {e}") + return 0 + + +class CacheManager: + """Manages caching for py-alpaca-api.""" + + def __init__(self, config: CacheConfig | None = None): + """Initialize cache manager. + + Args: + config: Cache configuration. If None, uses defaults. + """ + self.config = config or CacheConfig() + self._cache = self._create_cache() + self._hit_count = 0 + self._miss_count = 0 + + def _create_cache(self) -> LRUCache | RedisCache: + """Create appropriate cache backend. 
+ + Returns: + Cache implementation + """ + if not self.config.enabled or self.config.cache_type == CacheType.DISABLED: + logger.info("Caching disabled") + return LRUCache(max_size=0) # Dummy cache that stores nothing + + if self.config.cache_type == CacheType.REDIS: + try: + cache = RedisCache(self.config) + # Test the connection + cache._get_client() + return cache + except Exception as e: + logger.warning( + f"Failed to create Redis cache: {e}, falling back to memory cache" + ) + return LRUCache(self.config.max_size) + + return LRUCache(self.config.max_size) + + def generate_key(self, prefix: str, **kwargs) -> str: + """Generate cache key from prefix and parameters. + + Args: + prefix: Key prefix (e.g., "bars", "quotes") + **kwargs: Parameters to include in key + + Returns: + Cache key + """ + # Sort kwargs for consistent key generation + sorted_params = sorted(kwargs.items()) + param_str = json.dumps(sorted_params, sort_keys=True, default=str) + + # Create hash for long keys + if len(param_str) > 100: + param_hash = hashlib.md5(param_str.encode()).hexdigest() + return f"{prefix}:{param_hash}" + + return f"{prefix}:{param_str}" + + def get(self, key: str, data_type: str | None = None) -> Any | None: # noqa: ARG002 + """Get item from cache. + + Args: + key: Cache key + data_type: Optional data type for metrics + + Returns: + Cached value or None if not found + """ + if not self.config.enabled: + return None + + value = self._cache.get(key) + + if value is not None: + self._hit_count += 1 + logger.debug(f"Cache hit for {key}") + else: + self._miss_count += 1 + logger.debug(f"Cache miss for {key}") + + return value + + def set(self, key: str, value: Any, data_type: str, ttl: int | None = None) -> None: + """Set item in cache. 
+ + Args: + key: Cache key + value: Value to cache + data_type: Type of data (for TTL lookup) + ttl: Optional TTL override in seconds + """ + if not self.config.enabled: + return + + if ttl is None: + ttl = self.config.get_ttl(data_type) + + # Convert dataclass to dict for JSON serialization + if is_dataclass(value): + value = asdict(value) + elif isinstance(value, list) and value and is_dataclass(value[0]): + value = [asdict(item) for item in value] + + self._cache.set(key, value, ttl) + logger.debug(f"Cached {key} with TTL {ttl}s") + + def delete(self, key: str) -> bool: + """Delete item from cache. + + Args: + key: Cache key + + Returns: + True if deleted, False if not found + """ + if not self.config.enabled: + return False + + return self._cache.delete(key) + + def clear(self, prefix: str | None = None) -> int: + """Clear cache items. + + Args: + prefix: Optional prefix to clear only specific items + + Returns: + Number of items cleared + """ + if not self.config.enabled: + return 0 + + if prefix is None: + # Clear everything + size_before = self._cache.size() + self._cache.clear() + logger.info(f"Cleared entire cache ({size_before} items)") + return size_before + + # Clear items with specific prefix + if isinstance(self._cache, LRUCache): + keys_to_delete = [ + key for key in self._cache.cache if key.startswith(f"{prefix}:") + ] + for key in keys_to_delete: + self._cache.delete(key) + + logger.info(f"Cleared {len(keys_to_delete)} items with prefix '{prefix}'") + return len(keys_to_delete) + + # For Redis, we'd need to scan keys (expensive operation) + logger.warning("Prefix-based clearing not fully supported for Redis cache") + return 0 + + def invalidate_pattern(self, pattern: str) -> int: + """Invalidate cache items matching a pattern. 
+ + Args: + pattern: Pattern to match (e.g., "bars:*AAPL*") + + Returns: + Number of items invalidated + """ + if not self.config.enabled: + return 0 + + count = 0 + if isinstance(self._cache, LRUCache): + import fnmatch + + keys_to_delete = [ + key for key in self._cache.cache if fnmatch.fnmatch(key, pattern) + ] + for key in keys_to_delete: + self._cache.delete(key) + count += 1 + + logger.info(f"Invalidated {count} items matching pattern '{pattern}'") + return count + + def get_stats(self) -> dict[str, Any]: + """Get cache statistics. + + Returns: + Dictionary with cache stats + """ + hit_rate = 0.0 + total = self._hit_count + self._miss_count + if total > 0: + hit_rate = self._hit_count / total + + return { + "enabled": self.config.enabled, + "type": self.config.cache_type.value, + "size": self._cache.size() if self.config.enabled else 0, + "max_size": self.config.max_size, + "hit_count": self._hit_count, + "miss_count": self._miss_count, + "hit_rate": hit_rate, + "total_requests": total, + } + + def reset_stats(self) -> None: + """Reset cache statistics.""" + self._hit_count = 0 + self._miss_count = 0 + logger.debug("Cache statistics reset") + + def cached(self, data_type: str, ttl: int | None = None) -> Callable: + """Decorator for caching function results. 
+ + Args: + data_type: Type of data being cached + ttl: Optional TTL override + + Returns: + Decorator function + """ + + def decorator(func: Callable) -> Callable: + def wrapper(*args, **kwargs): + # Generate cache key from function name and arguments + cache_key = self.generate_key( + f"{func.__module__}.{func.__name__}", + args=str(args), + kwargs=str(kwargs), + ) + + # Try to get from cache + cached_value = self.get(cache_key, data_type) + if cached_value is not None: + return cached_value + + # Call function and cache result + result = func(*args, **kwargs) + self.set(cache_key, result, data_type, ttl) + return result + + return wrapper + + return decorator diff --git a/src/py_alpaca_api/http/feed_manager.py b/src/py_alpaca_api/http/feed_manager.py new file mode 100644 index 0000000..710acd7 --- /dev/null +++ b/src/py_alpaca_api/http/feed_manager.py @@ -0,0 +1,298 @@ +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, ClassVar + +from py_alpaca_api.exceptions import APIRequestError, ValidationError + +logger = logging.getLogger(__name__) + + +class FeedType(Enum): + """Available data feed types.""" + + SIP = "sip" + IEX = "iex" + OTC = "otc" + + @classmethod + def from_string(cls, value: str) -> FeedType: + """Create FeedType from string value.""" + try: + return cls(value.lower()) + except ValueError as e: + raise ValidationError( + f"Invalid feed type: {value}. 
Must be one of {[f.value for f in cls]}" + ) from e + + +class SubscriptionLevel(Enum): + """User subscription levels.""" + + BASIC = "basic" + UNLIMITED = "unlimited" + BUSINESS = "business" + + @classmethod + def from_error(cls, error_message: str) -> SubscriptionLevel | None: + """Detect subscription level from error message.""" + error_lower = error_message.lower() + + if "subscription" in error_lower: + if "unlimited" in error_lower or "business" in error_lower: + return cls.UNLIMITED + return cls.BASIC + return None + + +@dataclass +class FeedConfig: + """Configuration for feed management.""" + + preferred_feed: FeedType = FeedType.SIP + fallback_feeds: list[FeedType] = field(default_factory=lambda: [FeedType.IEX]) + auto_fallback: bool = True + subscription_level: SubscriptionLevel | None = None + endpoint_feeds: dict[str, FeedType] = field(default_factory=dict) + + def get_feed_for_endpoint(self, endpoint: str) -> FeedType: + """Get the configured feed for a specific endpoint.""" + return self.endpoint_feeds.get(endpoint, self.preferred_feed) + + +class FeedManager: + """Manages data feed selection and fallback logic.""" + + # Endpoints that support feed parameter + FEED_SUPPORTED_ENDPOINTS: ClassVar[set[str]] = { + "bars", + "quotes", + "trades", + "snapshots", + "latest/bars", + "latest/quotes", + "latest/trades", + } + + # Feed availability by subscription level + SUBSCRIPTION_FEEDS: ClassVar[dict[SubscriptionLevel, list[FeedType]]] = { + SubscriptionLevel.BASIC: [FeedType.IEX], + SubscriptionLevel.UNLIMITED: [FeedType.SIP, FeedType.IEX, FeedType.OTC], + SubscriptionLevel.BUSINESS: [FeedType.SIP, FeedType.IEX, FeedType.OTC], + } + + def __init__(self, config: FeedConfig | None = None): + """Initialize the feed manager. + + Args: + config: Feed configuration. If None, uses defaults. 
+ """ + self.config = config or FeedConfig() + self._failed_feeds: dict[str, set[FeedType]] = {} + self._detected_subscription_level: SubscriptionLevel | None = None + + def get_feed(self, endpoint: str, symbol: str | None = None) -> str | None: + """Get the appropriate feed for an endpoint. + + Args: + endpoint: The API endpoint being called + symbol: Optional symbol for endpoint-specific logic + + Returns: + Feed parameter value or None if endpoint doesn't support feeds + """ + if not self._supports_feed(endpoint): + return None + + feed = self.config.get_feed_for_endpoint(endpoint) + + # Check if this feed has previously failed + endpoint_key = f"{endpoint}:{symbol}" if symbol else endpoint + if ( + endpoint_key in self._failed_feeds + and feed in self._failed_feeds[endpoint_key] + ): + # Try to use fallback + for fallback in self.config.fallback_feeds: + if fallback not in self._failed_feeds.get(endpoint_key, set()): + logger.info(f"Using fallback feed {fallback.value} for {endpoint}") + return fallback.value + + return feed.value + + def handle_feed_error( + self, + endpoint: str, + feed: str, + error: APIRequestError, + symbol: str | None = None, + ) -> str | None: + """Handle feed-related errors and return alternative feed if available. 
+ + Args: + endpoint: The API endpoint that failed + feed: The feed that caused the error + error: The API error + symbol: Optional symbol for endpoint-specific tracking + + Returns: + Alternative feed to try, or None if no alternatives available + """ + if not self.config.auto_fallback: + return None + + # Try to detect subscription level from error + error_msg = str(error) + detected_level = SubscriptionLevel.from_error(error_msg) + if detected_level and not self._detected_subscription_level: + self._detected_subscription_level = detected_level + logger.info(f"Detected subscription level: {detected_level.value}") + + # Track failed feed + endpoint_key = f"{endpoint}:{symbol}" if symbol else endpoint + if endpoint_key not in self._failed_feeds: + self._failed_feeds[endpoint_key] = set() + + try: + feed_type = FeedType.from_string(feed) + self._failed_feeds[endpoint_key].add(feed_type) + logger.warning(f"Feed {feed} failed for {endpoint_key}: {error_msg}") + except ValidationError: + logger.exception(f"Invalid feed type in error handling: {feed}") + return None + + # Find alternative feed + for fallback in self.config.fallback_feeds: + if fallback not in self._failed_feeds[ + endpoint_key + ] and self._is_feed_available(fallback): + logger.info(f"Falling back to {fallback.value} feed for {endpoint_key}") + return fallback.value + + logger.error(f"No alternative feeds available for {endpoint_key}") + return None + + def detect_subscription_level(self, api_client: Any) -> SubscriptionLevel: + """Detect user's subscription level by testing API access. 
+ + Args: + api_client: API client instance to test with + + Returns: + Detected subscription level + """ + # Try SIP feed first (requires Unlimited/Business) + try: + # Make a test request with SIP feed + test_endpoint = "latest/quotes" + test_params = {"symbols": "AAPL", "feed": FeedType.SIP.value} + + api_client._make_request( + "GET", f"/stocks/{test_endpoint}", params=test_params + ) + + # If successful, user has at least Unlimited + self._detected_subscription_level = SubscriptionLevel.UNLIMITED + logger.info("Detected Unlimited/Business subscription level") + + except APIRequestError as e: + # SIP failed, user likely has Basic subscription + if "subscription" in str(e).lower() or "unauthorized" in str(e).lower(): + self._detected_subscription_level = SubscriptionLevel.BASIC + logger.info("Detected Basic subscription level") + else: + # Unexpected error, default to Basic for safety + self._detected_subscription_level = SubscriptionLevel.BASIC + logger.warning( + f"Could not detect subscription level: {e}. Defaulting to Basic." + ) + + self.config.subscription_level = self._detected_subscription_level + return self._detected_subscription_level + + def validate_feed(self, endpoint: str, feed: str) -> bool: + """Validate if a feed is appropriate for an endpoint. + + Args: + endpoint: The API endpoint + feed: The feed to validate + + Returns: + True if feed is valid for endpoint + """ + if not self._supports_feed(endpoint): + return False + + try: + feed_type = FeedType.from_string(feed) + except ValidationError: + return False + + return self._is_feed_available(feed_type) + + def reset_failures(self, endpoint: str | None = None) -> None: + """Reset tracked feed failures. + + Args: + endpoint: Optional endpoint to reset. If None, resets all. 
+ """ + if endpoint: + keys_to_remove = [ + k for k in self._failed_feeds if k.startswith(f"{endpoint}:") + ] + for key in keys_to_remove: + del self._failed_feeds[key] + if endpoint in self._failed_feeds: + del self._failed_feeds[endpoint] + else: + self._failed_feeds.clear() + + logger.info(f"Reset feed failures for {endpoint or 'all endpoints'}") + + def _supports_feed(self, endpoint: str) -> bool: + """Check if an endpoint supports feed parameter. + + Args: + endpoint: The API endpoint + + Returns: + True if endpoint supports feed parameter + """ + # Check if any supported endpoint pattern matches + return any(supported in endpoint for supported in self.FEED_SUPPORTED_ENDPOINTS) + + def _is_feed_available(self, feed: FeedType) -> bool: + """Check if a feed is available based on subscription level. + + Args: + feed: The feed to check + + Returns: + True if feed is available + """ + if not self._detected_subscription_level and not self.config.subscription_level: + # If we don't know subscription level, assume all feeds available + return True + + level = self._detected_subscription_level or self.config.subscription_level + if level is None: + return True + available_feeds = self.SUBSCRIPTION_FEEDS.get(level, []) + return feed in available_feeds + + def get_available_feeds(self) -> list[FeedType]: + """Get list of available feeds based on subscription level. 
+ + Returns: + List of available feed types + """ + if not self._detected_subscription_level and not self.config.subscription_level: + # If unknown, return all feeds + return list(FeedType) + + level = self._detected_subscription_level or self.config.subscription_level + if level is None: + return list(FeedType) + return self.SUBSCRIPTION_FEEDS.get(level, [FeedType.IEX]) diff --git a/src/py_alpaca_api/models/account_config_model.py b/src/py_alpaca_api/models/account_config_model.py new file mode 100644 index 0000000..c6a8c0d --- /dev/null +++ b/src/py_alpaca_api/models/account_config_model.py @@ -0,0 +1,47 @@ +from dataclasses import dataclass + + +@dataclass +class AccountConfigModel: + """Model for account configuration settings. + + Attributes: + dtbp_check: Day trade buying power check setting ("entry", "exit", "both") + fractional_trading: Whether fractional trading is enabled + max_margin_multiplier: Maximum margin multiplier allowed ("1", "2", "4") + no_shorting: Whether short selling is disabled + pdt_check: Pattern day trader check setting ("entry", "exit", "both") + ptp_no_exception_entry: Whether PTP no exception entry is enabled + suspend_trade: Whether trading is suspended + trade_confirm_email: Trade confirmation email setting ("all", "none") + """ + + dtbp_check: str + fractional_trading: bool + max_margin_multiplier: str + no_shorting: bool + pdt_check: str + ptp_no_exception_entry: bool + suspend_trade: bool + trade_confirm_email: str + + +def account_config_class_from_dict(data: dict) -> AccountConfigModel: + """Create AccountConfigModel from API response dictionary. 
+ + Args: + data: Dictionary containing account configuration data from API + + Returns: + AccountConfigModel instance + """ + return AccountConfigModel( + dtbp_check=data.get("dtbp_check", "entry"), + fractional_trading=data.get("fractional_trading", False), + max_margin_multiplier=data.get("max_margin_multiplier", "1"), + no_shorting=data.get("no_shorting", False), + pdt_check=data.get("pdt_check", "entry"), + ptp_no_exception_entry=data.get("ptp_no_exception_entry", False), + suspend_trade=data.get("suspend_trade", False), + trade_confirm_email=data.get("trade_confirm_email", "all"), + ) diff --git a/src/py_alpaca_api/models/corporate_action_model.py b/src/py_alpaca_api/models/corporate_action_model.py new file mode 100644 index 0000000..0068638 --- /dev/null +++ b/src/py_alpaca_api/models/corporate_action_model.py @@ -0,0 +1,135 @@ +from dataclasses import dataclass +from typing import Any + + +@dataclass +class CorporateActionModel: + """Base model for corporate action announcements.""" + + id: str + corporate_action_id: str + ca_type: str + ca_sub_type: str | None + initiating_symbol: str | None + initiating_original_cusip: str | None + target_symbol: str | None + target_original_cusip: str | None + declaration_date: str | None + ex_date: str | None + record_date: str | None + payable_date: str | None + cash: float | None + old_rate: float | None + new_rate: float | None + + +@dataclass +class DividendModel(CorporateActionModel): + """Model for dividend corporate actions.""" + + cash_amount: float | None + dividend_type: str | None + frequency: int | None + + +@dataclass +class SplitModel(CorporateActionModel): + """Model for stock split corporate actions.""" + + split_from: float | None + split_to: float | None + + +@dataclass +class MergerModel(CorporateActionModel): + """Model for merger corporate actions.""" + + acquirer_symbol: str | None + acquirer_cusip: str | None + cash_rate: float | None + stock_rate: float | None + + +@dataclass +class 
SpinoffModel(CorporateActionModel): + """Model for spinoff corporate actions.""" + + new_symbol: str | None + new_cusip: str | None + ratio: float | None + + +def corporate_action_class_from_dict(data: dict[str, Any]) -> CorporateActionModel: + """Create appropriate corporate action model from dictionary. + + Args: + data: Dictionary containing corporate action data + + Returns: + CorporateActionModel or one of its subclasses based on ca_type + """ + ca_type = data.get("ca_type", "").lower() + + # Extract common fields + base_fields = { + "id": data.get("id", ""), + "corporate_action_id": data.get("corporate_action_id", ""), + "ca_type": data.get("ca_type", ""), + "ca_sub_type": data.get("ca_sub_type"), + "initiating_symbol": data.get("initiating_symbol"), + "initiating_original_cusip": data.get("initiating_original_cusip"), + "target_symbol": data.get("target_symbol"), + "target_original_cusip": data.get("target_original_cusip"), + "declaration_date": data.get("declaration_date"), + "ex_date": data.get("ex_date"), + "record_date": data.get("record_date"), + "payable_date": data.get("payable_date"), + "cash": data.get("cash"), + "old_rate": data.get("old_rate"), + "new_rate": data.get("new_rate"), + } + + if ca_type == "dividend": + return DividendModel( + **base_fields, + cash_amount=data.get("cash_amount"), + dividend_type=data.get("dividend_type"), + frequency=data.get("frequency"), + ) + if ca_type == "split": + return SplitModel( + **base_fields, + split_from=data.get("split_from"), + split_to=data.get("split_to"), + ) + if ca_type == "merger": + return MergerModel( + **base_fields, + acquirer_symbol=data.get("acquirer_symbol"), + acquirer_cusip=data.get("acquirer_cusip"), + cash_rate=data.get("cash_rate"), + stock_rate=data.get("stock_rate"), + ) + if ca_type == "spinoff": + return SpinoffModel( + **base_fields, + new_symbol=data.get("new_symbol"), + new_cusip=data.get("new_cusip"), + ratio=data.get("ratio"), + ) + # Return base model for unknown types + 
return CorporateActionModel(**base_fields) + + +def extract_corporate_action_data(data: dict[str, Any]) -> dict[str, Any]: + """Extract and transform corporate action data from API response. + + Args: + data: Raw API response data + + Returns: + Transformed dictionary ready for model creation + """ + # This function can handle any data transformation needed + # between the API response and our model structure + return data diff --git a/src/py_alpaca_api/models/snapshot_model.py b/src/py_alpaca_api/models/snapshot_model.py new file mode 100644 index 0000000..fb33253 --- /dev/null +++ b/src/py_alpaca_api/models/snapshot_model.py @@ -0,0 +1,87 @@ +from dataclasses import dataclass + +import pendulum + +from py_alpaca_api.models.quote_model import QuoteModel, quote_class_from_dict +from py_alpaca_api.models.trade_model import TradeModel, trade_class_from_dict + + +@dataclass +class BarModel: + timestamp: str # Store as string for consistency with other models + open: float + high: float + low: float + close: float + volume: int + trade_count: int | None = None + vwap: float | None = None + + +@dataclass +class SnapshotModel: + symbol: str + latest_trade: TradeModel | None = None + latest_quote: QuoteModel | None = None + minute_bar: BarModel | None = None + daily_bar: BarModel | None = None + prev_daily_bar: BarModel | None = None + + +def bar_class_from_dict(data: dict) -> BarModel: + # Parse timestamp + timestamp_str = data.get("t", "") + if timestamp_str: + timestamp = pendulum.parse(timestamp_str, tz="America/New_York") + if isinstance(timestamp, pendulum.DateTime): + timestamp_str = timestamp.strftime("%Y-%m-%d %H:%M:%S") + else: + timestamp_str = str(timestamp) + + return BarModel( + timestamp=timestamp_str, + open=float(data.get("o", 0.0)), + high=float(data.get("h", 0.0)), + low=float(data.get("l", 0.0)), + close=float(data.get("c", 0.0)), + volume=int(data.get("v", 0)), + trade_count=int(data["n"]) if "n" in data and data["n"] is not None else None, + 
vwap=float(data["vw"]) if "vw" in data and data["vw"] is not None else None, + ) + + +def snapshot_class_from_dict(data: dict) -> SnapshotModel: + snapshot_data = {"symbol": data.get("symbol", "")} + + if data.get("latestTrade"): + trade_data = data["latestTrade"] + snapshot_data["latest_trade"] = trade_class_from_dict( + trade_data, data.get("symbol", "") + ) + + if data.get("latestQuote"): + quote_data = data["latestQuote"] + # Map API field names to model field names + quote_dict = { + "symbol": data.get("symbol", ""), + "timestamp": quote_data.get("t", ""), + "ask": quote_data.get("ap", 0.0), + "ask_size": quote_data.get("as", 0), + "bid": quote_data.get("bp", 0.0), + "bid_size": quote_data.get("bs", 0), + } + snapshot_data["latest_quote"] = quote_class_from_dict(quote_dict) + + if data.get("minuteBar"): + bar_data = data["minuteBar"] + snapshot_data["minute_bar"] = bar_class_from_dict(bar_data) + + if data.get("dailyBar"): + bar_data = data["dailyBar"] + snapshot_data["daily_bar"] = bar_class_from_dict(bar_data) + + if data.get("prevDailyBar"): + bar_data = data["prevDailyBar"] + snapshot_data["prev_daily_bar"] = bar_class_from_dict(bar_data) + + return SnapshotModel(**snapshot_data) diff --git a/src/py_alpaca_api/models/trade_model.py b/src/py_alpaca_api/models/trade_model.py new file mode 100644 index 0000000..9385092 --- /dev/null +++ b/src/py_alpaca_api/models/trade_model.py @@ -0,0 +1,77 @@ +from dataclasses import dataclass +from typing import Any + + +@dataclass +class TradeModel: + """Model for individual stock trade data.""" + + timestamp: str # RFC-3339 format timestamp + symbol: str + exchange: str + price: float + size: int + conditions: list[str] | None + id: int + tape: str + + +def trade_class_from_dict( + data: dict[str, Any], symbol: str | None = None +) -> TradeModel: + """Create TradeModel from API response dictionary. 
+ + Args: + data: Dictionary containing trade data from API + symbol: Optional symbol to use if not in data + + Returns: + TradeModel instance + """ + return TradeModel( + timestamp=data.get("t", ""), + symbol=data.get("symbol", symbol or ""), + exchange=data.get("x", ""), + price=float(data.get("p", 0.0)), + size=int(data.get("s", 0)), + conditions=data.get("c", []), + id=int(data.get("i", 0)), + tape=data.get("z", ""), + ) + + +@dataclass +class LatestTradeModel: + """Model for latest trade data with symbol.""" + + trade: TradeModel + symbol: str + + +@dataclass +class TradesResponse: + """Response model for trades endpoint with pagination.""" + + trades: list[TradeModel] + symbol: str + next_page_token: str | None = None + + +def extract_trades_data(data: dict[str, Any]) -> dict[str, Any]: + """Extract and transform trades data from API response. + + Args: + data: Raw API response data + + Returns: + Transformed dictionary ready for model creation + """ + # Handle both single trade and multiple trades response formats + if "trades" in data: + # Multiple trades response + return data + if "trade" in data: + # Single latest trade response + return {"trades": [data["trade"]], "symbol": data.get("symbol", "")} + # Direct trade data + return {"trades": [data], "symbol": data.get("symbol", "")} diff --git a/src/py_alpaca_api/stock/__init__.py b/src/py_alpaca_api/stock/__init__.py index 33c4b37..5ac6229 100644 --- a/src/py_alpaca_api/stock/__init__.py +++ b/src/py_alpaca_api/stock/__init__.py @@ -1,8 +1,11 @@ from py_alpaca_api.stock.assets import Assets from py_alpaca_api.stock.history import History from py_alpaca_api.stock.latest_quote import LatestQuote +from py_alpaca_api.stock.metadata import Metadata from py_alpaca_api.stock.predictor import Predictor from py_alpaca_api.stock.screener import Screener +from py_alpaca_api.stock.snapshots import Snapshots +from py_alpaca_api.stock.trades import Trades from py_alpaca_api.trading.market import Market @@ -25,13 +28,6 
@@ def __init__( headers=headers, base_url=base_url, data_url=data_url, market=market ) - self._initialize_components( - headers=headers, - base_url=base_url, - data_url=data_url, - market=market, - ) - def _initialize_components( self, headers: dict[str, str], @@ -46,3 +42,6 @@ def _initialize_components( ) self.predictor = Predictor(history=self.history, screener=self.screener) self.latest_quote = LatestQuote(headers=headers) + self.metadata = Metadata(headers=headers) + self.snapshots = Snapshots(headers=headers) + self.trades = Trades(headers=headers) diff --git a/src/py_alpaca_api/stock/history.py b/src/py_alpaca_api/stock/history.py index 787ad90..fd84f40 100644 --- a/src/py_alpaca_api/stock/history.py +++ b/src/py_alpaca_api/stock/history.py @@ -1,5 +1,6 @@ import json from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor, as_completed import pandas as pd @@ -9,6 +10,8 @@ class History: + BATCH_SIZE = 200 # Alpaca API limit for multi-symbol requests + def __init__(self, data_url: str, headers: dict[str, str], asset: Assets) -> None: """Initializes an instance of the History class. @@ -51,7 +54,7 @@ def check_if_stock(self, symbol: str) -> AssetModel: ########################################### def get_stock_data( self, - symbol: str, + symbol: str | list[str], start: str, end: str, timeframe: str = "1d", @@ -61,28 +64,59 @@ def get_stock_data( sort: str = "asc", adjustment: str = "raw", ) -> pd.DataFrame: - """Retrieves historical stock data for a given symbol within a specified date range and timeframe. + """Retrieves historical stock data for one or more symbols within a specified date range and timeframe. Args: - symbol: The stock symbol to fetch data for. + symbol: The stock symbol(s) to fetch data for. Can be a single symbol string or list of symbols. start: The start date for historical data in the format "YYYY-MM-DD". end: The end date for historical data in the format "YYYY-MM-DD". 
timeframe: The timeframe for the historical data. Default is "1d". feed: The data feed source. Default is "sip". currency: The currency for historical data. Default is "USD". - limit: The number of data points to fetch. Default is 1000. + limit: The number of data points to fetch per symbol. Default is 1000. sort: The sort order for the data. Default is "asc". adjustment: The adjustment for historical data. Default is "raw". Returns: - A pandas DataFrame containing the historical stock data for the given symbol and time range. + A pandas DataFrame containing the historical stock data for the given symbol(s) and time range. Raises: ValueError: If the given timeframe is not one of the allowed values. """ - self.check_if_stock(symbol) + # Handle single symbol or list of symbols + is_single = isinstance(symbol, str) + if is_single: + assert isinstance(symbol, str) # Type guard for mypy + symbols_list: list[str] = [symbol] + single_symbol: str = symbol + else: + assert isinstance(symbol, list) # Type guard for mypy + symbols_list = symbol + single_symbol = "" # Won't be used in multi-symbol case + + # Validate symbols are stocks + for sym in symbols_list: + self.check_if_stock(sym) - url = f"{self.data_url}/stocks/{symbol}/bars" + # If more than BATCH_SIZE symbols, need to batch the requests + if not is_single and len(symbols_list) > self.BATCH_SIZE: + return self._get_batched_stock_data( + symbols_list, + start, + end, + timeframe, + feed, + currency, + limit, + sort, + adjustment, + ) + + # Determine if using single or multi-symbol endpoint + if is_single: + url = f"{self.data_url}/stocks/{single_symbol}/bars" + else: + url = f"{self.data_url}/stocks/bars" timeframe_mapping: dict = { "1m": "1Min", @@ -111,8 +145,105 @@ def get_stock_data( "feed": feed, "sort": sort, } - symbol_data = self.get_historical_data(symbol, url, params) - return self.preprocess_data(symbol_data, symbol) + + # Add symbols parameter for multi-symbol request + if not is_single: + 
params["symbols"] = ",".join(symbols_list) + + symbol_data = self.get_historical_data(symbols_list, url, params, is_single) + + # Process data based on single or multi-symbol + if is_single: + return self.preprocess_data(symbol_data[single_symbol], single_symbol) + return self.preprocess_multi_data(symbol_data) + + def _get_batched_stock_data( + self, + symbols: list[str], + start: str, + end: str, + timeframe: str, + feed: str, + currency: str, + limit: int, + sort: str, + adjustment: str, + ) -> pd.DataFrame: + """Handle large symbol lists by batching requests. + + Args: + symbols: List of symbols to fetch data for. + start: The start date for historical data. + end: The end date for historical data. + timeframe: The timeframe for the historical data. + feed: The data feed source. + currency: The currency for historical data. + limit: The number of data points to fetch per symbol. + sort: The sort order for the data. + adjustment: The adjustment for historical data. + + Returns: + A pandas DataFrame containing the historical stock data for all symbols. 
+ """ + # Split symbols into batches + batches = [ + symbols[i : i + self.BATCH_SIZE] + for i in range(0, len(symbols), self.BATCH_SIZE) + ] + + # Use ThreadPoolExecutor for concurrent batch requests + all_dfs = [] + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [] + for batch in batches: + future = executor.submit( + self.get_stock_data, + batch, + start, + end, + timeframe, + feed, + currency, + limit, + sort, + adjustment, + ) + futures.append(future) + + for future in as_completed(futures): + try: + df = future.result() + if not df.empty: + all_dfs.append(df) + except Exception as e: + # Log error but continue with other batches + print(f"Error fetching batch: {e}") + + if all_dfs: + return pd.concat(all_dfs, ignore_index=True).sort_values(["symbol", "date"]) + return pd.DataFrame() + + @staticmethod + def preprocess_multi_data( + symbols_data: dict[str, list[defaultdict]], + ) -> pd.DataFrame: + """Preprocess data for multiple symbols. + + Args: + symbols_data: A dictionary mapping symbols to their bar data. + + Returns: + A pandas DataFrame containing the preprocessed historical stock data for all symbols. + """ + all_dfs = [] + for symbol, data in symbols_data.items(): + if data: # Only process if data exists + df = History.preprocess_data(data, symbol) + all_dfs.append(df) + + if all_dfs: + return pd.concat(all_dfs, ignore_index=True).sort_values(["symbol", "date"]) + return pd.DataFrame() ########################################### # /////////// PreProcess Data \\\\\\\\\\\ # @@ -169,35 +300,51 @@ def preprocess_data(symbol_data: list[defaultdict], symbol: str) -> pd.DataFrame # ///////// Get Historical Data \\\\\\\\\ # ########################################### def get_historical_data( - self, symbol: str, url: str, params: dict - ) -> list[defaultdict]: - """Retrieves historical data for a given symbol. 
+ self, symbols: list[str], url: str, params: dict, is_single: bool + ) -> dict[str, list[defaultdict]]: + """Retrieves historical data for given symbol(s). Args: - symbol (str): The symbol for which to retrieve historical data. - url (str): The URL to send the request to. - params (dict): Additional parameters to include in the request. + symbols: List of symbols for which to retrieve historical data. + url: The URL to send the request to. + params: Additional parameters to include in the request. + is_single: Whether this is a single-symbol request. Returns: - list[defaultdict]: A list of historical data for the given symbol. + dict[str, list[defaultdict]]: A dictionary mapping symbols to their historical data. """ page_token = None symbols_data = defaultdict(list) + while True: - params["page_token"] = page_token + if page_token: + params["page_token"] = page_token + response = json.loads( Requests() .request(method="GET", url=url, headers=self.headers, params=params) .text ) - if not response.get("bars"): - raise Exception( - f"No historical data found for {symbol}, with the given parameters." - ) + # Handle single vs multi-symbol response format + if is_single: + if not response.get("bars"): + raise Exception( + f"No historical data found for {symbols[0]}, with the given parameters." + ) + symbols_data[symbols[0]].extend(response.get("bars", [])) + else: + # Multi-symbol response has bars nested under symbol keys + bars = response.get("bars", {}) + if not bars: + raise Exception( + f"No historical data found for symbols: {', '.join(symbols)}, with the given parameters." 
+ ) + for symbol, symbol_bars in bars.items(): + symbols_data[symbol].extend(symbol_bars) - symbols_data[symbol].extend(response.get("bars", [])) - page_token = response.get("next_page_token", "") + page_token = response.get("next_page_token") if not page_token: break - return symbols_data[symbol] + + return symbols_data diff --git a/src/py_alpaca_api/stock/latest_quote.py b/src/py_alpaca_api/stock/latest_quote.py index 7751e38..d716a97 100644 --- a/src/py_alpaca_api/stock/latest_quote.py +++ b/src/py_alpaca_api/stock/latest_quote.py @@ -1,10 +1,13 @@ import json +from concurrent.futures import ThreadPoolExecutor, as_completed from py_alpaca_api.http.requests import Requests from py_alpaca_api.models.quote_model import QuoteModel, quote_class_from_dict class LatestQuote: + BATCH_SIZE = 200 # Alpaca API limit for multi-symbol requests + def __init__(self, headers: dict[str, str]) -> None: self.headers = headers @@ -14,6 +17,19 @@ def get( feed: str = "iex", currency: str = "USD", ) -> list[QuoteModel] | QuoteModel: + """Get latest quotes for one or more symbols. + + Args: + symbol: A string or list of strings representing the stock symbol(s). + feed: The data feed source. Default is "iex". + currency: The currency for the quotes. Default is "USD". + + Returns: + A single QuoteModel or list of QuoteModel objects. + + Raises: + ValueError: If symbol is None/empty or if feed is invalid. + """ if symbol is None or symbol == "": raise ValueError("Symbol is required. 
Must be a string or list of strings.") @@ -21,15 +37,44 @@ def get( if feed not in valid_feeds: raise ValueError("Invalid feed, must be one of: 'iex', 'sip', 'otc'") - if isinstance(symbol, list): - symbol = ",".join(symbol).replace(" ", "").upper() + # Handle single vs multiple symbols + is_single = isinstance(symbol, str) + if is_single: + assert isinstance(symbol, str) # Type guard for mypy + symbols = [symbol.upper().strip()] else: - symbol = symbol.replace(" ", "").upper() + assert isinstance(symbol, list) # Type guard for mypy + symbols = [s.upper().strip() for s in symbol] + + # If more than BATCH_SIZE symbols, need to batch the requests + if len(symbols) > self.BATCH_SIZE: + quotes = self._get_batched_quotes(symbols, feed, currency) + else: + quotes = self._fetch_quotes(symbols, feed, currency) + + # Return single quote if single symbol requested + if is_single and quotes: + return quotes[0] + return quotes + + def _fetch_quotes( + self, symbols: list[str], feed: str, currency: str + ) -> list[QuoteModel]: + """Fetch quotes for a list of symbols. + Args: + symbols: List of stock symbols. + feed: The data feed source. + currency: The currency for the quotes. + + Returns: + List of QuoteModel objects. + """ url = "https://data.alpaca.markets/v2/stocks/quotes/latest" + symbols_str = ",".join(symbols) params: dict[str, str | bool | float | int] = { - "symbols": symbol, + "symbols": symbols_str, "feed": feed, "currency": currency, } @@ -41,8 +86,7 @@ def get( ) quotes = [] - - for key, value in response["quotes"].items(): + for key, value in response.get("quotes", {}).items(): quotes.append( quote_class_from_dict( { @@ -56,4 +100,41 @@ def get( ) ) - return quotes if len(quotes) > 1 else quotes[0] + return quotes + + def _get_batched_quotes( + self, symbols: list[str], feed: str, currency: str + ) -> list[QuoteModel]: + """Handle large symbol lists by batching requests. + + Args: + symbols: List of stock symbols. + feed: The data feed source. 
+ currency: The currency for the quotes. + + Returns: + List of QuoteModel objects. + """ + # Split symbols into batches + batches = [ + symbols[i : i + self.BATCH_SIZE] + for i in range(0, len(symbols), self.BATCH_SIZE) + ] + + # Use ThreadPoolExecutor for concurrent batch requests + all_quotes = [] + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [] + for batch in batches: + future = executor.submit(self._fetch_quotes, batch, feed, currency) + futures.append(future) + + for future in as_completed(futures): + try: + quotes = future.result() + all_quotes.extend(quotes) + except Exception as e: + # Log error but continue with other batches + print(f"Error fetching batch: {e}") + + return all_quotes diff --git a/src/py_alpaca_api/stock/metadata.py b/src/py_alpaca_api/stock/metadata.py new file mode 100644 index 0000000..aa2542c --- /dev/null +++ b/src/py_alpaca_api/stock/metadata.py @@ -0,0 +1,187 @@ +import json + +from py_alpaca_api.exceptions import APIRequestError, ValidationError +from py_alpaca_api.http.requests import Requests + + +class Metadata: + """Market metadata API for condition codes and exchange codes.""" + + def __init__(self, headers: dict[str, str]) -> None: + """Initialize the Metadata class. + + Args: + headers: Dictionary containing authentication headers. + """ + self.headers = headers + self.base_url = "https://data.alpaca.markets/v2/stocks/meta" + # Cache for metadata that rarely changes + self._exchange_cache: dict[str, str] | None = None + self._condition_cache: dict[str, dict[str, str]] = {} + + def get_exchange_codes(self, use_cache: bool = True) -> dict[str, str]: + """Get the mapping between exchange codes and exchange names. + + Args: + use_cache: Whether to use cached data if available. Defaults to True. + + Returns: + Dictionary mapping exchange codes to exchange names. + + Raises: + APIRequestError: If the API request fails. 
+ """ + if use_cache and self._exchange_cache is not None: + return self._exchange_cache + + url = f"{self.base_url}/exchanges" + + try: + response = json.loads( + Requests().request(method="GET", url=url, headers=self.headers).text + ) + except Exception as e: + raise APIRequestError(message=f"Failed to get exchange codes: {e!s}") from e + + if not response: + raise APIRequestError(message="No exchange data returned") + + # Cache the result + self._exchange_cache = response + return response + + def get_condition_codes( + self, + ticktype: str = "trade", + tape: str = "A", + use_cache: bool = True, + ) -> dict[str, str]: + """Get the mapping between condition codes and condition names. + + Args: + ticktype: Type of conditions to retrieve ("trade" or "quote"). Defaults to "trade". + tape: Market tape ("A" for NYSE, "B" for NASDAQ, "C" for other). Defaults to "A". + use_cache: Whether to use cached data if available. Defaults to True. + + Returns: + Dictionary mapping condition codes to condition descriptions. + + Raises: + ValidationError: If invalid parameters are provided. + APIRequestError: If the API request fails. + """ + # Validate parameters + valid_ticktypes = ["trade", "quote"] + if ticktype not in valid_ticktypes: + raise ValidationError( + f"Invalid ticktype. Must be one of: {', '.join(valid_ticktypes)}" + ) + + valid_tapes = ["A", "B", "C"] + if tape not in valid_tapes: + raise ValidationError( + f"Invalid tape. 
Must be one of: {', '.join(valid_tapes)}" + ) + + # Check cache + cache_key = f"{ticktype}_{tape}" + if use_cache and cache_key in self._condition_cache: + return self._condition_cache[cache_key] + + url = f"{self.base_url}/conditions/{ticktype}" + params: dict[str, str | bool | float | int] = {"tape": tape} + + try: + response = json.loads( + Requests() + .request(method="GET", url=url, headers=self.headers, params=params) + .text + ) + except Exception as e: + raise APIRequestError( + message=f"Failed to get condition codes: {e!s}" + ) from e + + if response is None: + raise APIRequestError(message="No condition data returned") + + # Cache the result + self._condition_cache[cache_key] = response + return response + + def get_all_condition_codes( + self, use_cache: bool = True + ) -> dict[str, dict[str, dict[str, str]]]: + """Get all condition codes for all tick types and tapes. + + Args: + use_cache: Whether to use cached data if available. Defaults to True. + + Returns: + Nested dictionary with structure: + { + "trade": { + "A": {condition_code: description, ...}, + "B": {condition_code: description, ...}, + "C": {condition_code: description, ...} + }, + "quote": { + "A": {condition_code: description, ...}, + "B": {condition_code: description, ...}, + "C": {condition_code: description, ...} + } + } + + Raises: + APIRequestError: If any API request fails. + """ + result: dict[str, dict[str, dict[str, str]]] = {} + + for ticktype in ["trade", "quote"]: + result[ticktype] = {} + for tape in ["A", "B", "C"]: + try: + result[ticktype][tape] = self.get_condition_codes( + ticktype=ticktype, tape=tape, use_cache=use_cache + ) + except APIRequestError: + # Some tape/ticktype combinations might not be available + result[ticktype][tape] = {} + + return result + + def clear_cache(self) -> None: + """Clear all cached metadata. + + This forces the next request to fetch fresh data from the API. 
+ """ + self._exchange_cache = None + self._condition_cache = {} + + def lookup_exchange(self, code: str) -> str | None: + """Look up an exchange name by its code. + + Args: + code: The exchange code to look up. + + Returns: + The exchange name if found, None otherwise. + """ + exchanges = self.get_exchange_codes() + return exchanges.get(code) + + def lookup_condition( + self, code: str, ticktype: str = "trade", tape: str = "A" + ) -> str | None: + """Look up a condition description by its code. + + Args: + code: The condition code to look up. + ticktype: Type of condition ("trade" or "quote"). Defaults to "trade". + tape: Market tape ("A", "B", or "C"). Defaults to "A". + + Returns: + The condition description if found, None otherwise. + """ + conditions = self.get_condition_codes(ticktype=ticktype, tape=tape) + return conditions.get(code) diff --git a/src/py_alpaca_api/stock/snapshots.py b/src/py_alpaca_api/stock/snapshots.py new file mode 100644 index 0000000..23584b9 --- /dev/null +++ b/src/py_alpaca_api/stock/snapshots.py @@ -0,0 +1,141 @@ +import json + +from py_alpaca_api.exceptions import APIRequestError, ValidationError +from py_alpaca_api.http.requests import Requests +from py_alpaca_api.models.snapshot_model import SnapshotModel, snapshot_class_from_dict + + +class Snapshots: + def __init__(self, headers: dict[str, str]) -> None: + """Initialize the Snapshots class. + + Args: + headers: Dictionary containing authentication headers. + """ + self.headers = headers + self.base_url = "https://data.alpaca.markets/v2/stocks" + + def get_snapshot( + self, + symbol: str, + feed: str = "iex", + ) -> SnapshotModel: + """Get a snapshot of a single stock symbol. + + The snapshot includes the latest trade, latest quote, minute bar, + daily bar, and previous daily bar data. + + Args: + symbol: The stock symbol to get snapshot for. + feed: The data feed to use ("iex", "sip", or "otc"). Defaults to "iex". + + Returns: + A SnapshotModel containing the snapshot data. 
+ + Raises: + ValidationError: If symbol is invalid or feed is invalid. + APIRequestError: If the API request fails. + """ + if not symbol or not isinstance(symbol, str): + raise ValidationError("Symbol is required and must be a string.") + + valid_feeds = ["iex", "sip", "otc"] + if feed not in valid_feeds: + raise ValidationError( + f"Invalid feed. Must be one of: {', '.join(valid_feeds)}" + ) + + symbol = symbol.upper().strip() + + url = f"{self.base_url}/{symbol}/snapshot" + + params: dict[str, str | bool | float | int] = {"feed": feed} + + try: + response = json.loads( + Requests() + .request(method="GET", url=url, headers=self.headers, params=params) + .text + ) + except Exception as e: + raise APIRequestError( + message=f"Failed to get snapshot for {symbol}: {e!s}" + ) from e + + if not response: + raise APIRequestError(message=f"No snapshot data returned for {symbol}") + + response["symbol"] = symbol + return snapshot_class_from_dict(response) + + def get_snapshots( + self, + symbols: list[str] | str, + feed: str = "iex", + ) -> list[SnapshotModel] | dict[str, SnapshotModel]: + """Get snapshots for multiple stock symbols. + + The snapshot includes the latest trade, latest quote, minute bar, + daily bar, and previous daily bar data for each symbol. + + Args: + symbols: A list of stock symbols or comma-separated string of symbols. + feed: The data feed to use ("iex", "sip", or "otc"). Defaults to "iex". + + Returns: + A dictionary mapping symbols to their SnapshotModel objects, or a list + of SnapshotModel objects if only one symbol is provided. + + Raises: + ValidationError: If symbols are invalid or feed is invalid. + APIRequestError: If the API request fails. + """ + if not symbols: + raise ValidationError("Symbols are required.") + + valid_feeds = ["iex", "sip", "otc"] + if feed not in valid_feeds: + raise ValidationError( + f"Invalid feed. 
Must be one of: {', '.join(valid_feeds)}" + ) + + if isinstance(symbols, str): + symbols_str = symbols.upper().strip() + symbols_list = [s.strip() for s in symbols_str.split(",")] + else: + symbols_list = [s.upper().strip() for s in symbols] + symbols_str = ",".join(symbols_list) + + if not symbols_str: + raise ValidationError("At least one symbol is required.") + + url = f"{self.base_url}/snapshots" + + params: dict[str, str | bool | float | int] = { + "symbols": symbols_str, + "feed": feed, + } + + try: + response = json.loads( + Requests() + .request(method="GET", url=url, headers=self.headers, params=params) + .text + ) + except Exception as e: + raise APIRequestError(message=f"Failed to get snapshots: {e!s}") from e + + if not response: + raise APIRequestError(message="No snapshot data returned") + + # The API returns symbols as top-level keys directly + snapshots = {} + for symbol, data in response.items(): + if isinstance(data, dict): # Ensure it's snapshot data + data["symbol"] = symbol + snapshots[symbol] = snapshot_class_from_dict(data) + + if len(symbols_list) == 1: + return list(snapshots.values()) + + return snapshots diff --git a/src/py_alpaca_api/stock/trades.py b/src/py_alpaca_api/stock/trades.py new file mode 100644 index 0000000..2cd58bd --- /dev/null +++ b/src/py_alpaca_api/stock/trades.py @@ -0,0 +1,369 @@ +import json +from datetime import datetime +from typing import Literal + +from py_alpaca_api.exceptions import APIRequestError, ValidationError +from py_alpaca_api.http.requests import Requests +from py_alpaca_api.models.trade_model import ( + TradeModel, + TradesResponse, + trade_class_from_dict, +) + + +class Trades: + def __init__(self, headers: dict[str, str]) -> None: + self.headers = headers + self.base_url = "https://data.alpaca.markets/v2" + + def get_trades( + self, + symbol: str, + start: str, + end: str, + limit: int = 1000, + feed: Literal["iex", "sip", "otc"] | None = None, + page_token: str | None = None, + asof: str | None = 
None, + ) -> TradesResponse: + """Retrieve historical trades for a symbol. + + Args: + symbol: The stock symbol to retrieve trades for + start: Start time in RFC-3339 format (YYYY-MM-DDTHH:MM:SSZ) + end: End time in RFC-3339 format (YYYY-MM-DDTHH:MM:SSZ) + limit: Number of trades to return (1-10000, default 1000) + feed: Data feed to use (iex, sip, otc) + page_token: Token for pagination + asof: As-of time for historical data in RFC-3339 format + + Returns: + TradesResponse with list of trades and pagination token + + Raises: + ValidationError: If parameters are invalid + APIRequestError: If the API request fails + """ + # Validate parameters + if not symbol: + raise ValidationError("Symbol is required") + + if limit < 1 or limit > 10000: + raise ValidationError("Limit must be between 1 and 10000") + + # Validate date formats (must include time) + try: + if "T" not in start or "T" not in end: + raise ValueError("Date must include time (RFC-3339 format)") + datetime.fromisoformat(start.replace("Z", "+00:00")) + datetime.fromisoformat(end.replace("Z", "+00:00")) + except (ValueError, AttributeError) as e: + raise ValidationError( + f"Invalid date format. 
Use RFC-3339 format (YYYY-MM-DDTHH:MM:SSZ): {e}" + ) from e + + # Build query parameters + params: dict[str, str | bool | float | int] = { + "start": start, + "end": end, + "limit": limit, + } + + if feed: + params["feed"] = feed + if page_token: + params["page_token"] = page_token + if asof: + params["asof"] = asof + + # Make request + url = f"{self.base_url}/stocks/{symbol}/trades" + http_response = Requests().request( + "GET", url, headers=self.headers, params=params + ) + + if http_response.status_code != 200: + raise APIRequestError( + http_response.status_code, + f"Failed to retrieve trades: {http_response.text}", + ) + + response = json.loads(http_response.text) if http_response.text else {} + + # Parse trades + trades = [] + for trade_data in response.get("trades", []) or []: + trades.append(trade_class_from_dict(trade_data, symbol)) + + return TradesResponse( + trades=trades, + symbol=response.get("symbol", symbol), + next_page_token=response.get("next_page_token"), + ) + + def get_latest_trade( + self, + symbol: str, + feed: Literal["iex", "sip", "otc"] | None = None, + asof: str | None = None, + ) -> TradeModel: + """Get the latest trade for a symbol. 
+ + Args: + symbol: The stock symbol to retrieve latest trade for + feed: Data feed to use (iex, sip, otc) + asof: As-of time for historical data in RFC-3339 format + + Returns: + TradeModel with the latest trade data + + Raises: + ValidationError: If symbol is invalid + APIRequestError: If the API request fails + """ + if not symbol: + raise ValidationError("Symbol is required") + + # Build query parameters + params: dict[str, str | bool | float | int] = {"symbols": symbol} + + if feed: + params["feed"] = feed + if asof: + params["asof"] = asof + + # Make request + url = f"{self.base_url}/stocks/trades/latest" + http_response = Requests().request( + "GET", url, headers=self.headers, params=params + ) + + if http_response.status_code != 200: + raise APIRequestError( + http_response.status_code, + f"Failed to retrieve latest trade: {http_response.text}", + ) + + response = json.loads(http_response.text) + + # Handle response format + if "trades" in response and symbol in response["trades"]: + trade_data = response["trades"][symbol] + elif symbol in response: + trade_data = response[symbol] + else: + raise APIRequestError( + 404, + f"No trade data found for symbol: {symbol}", + ) + + return trade_class_from_dict(trade_data, symbol) + + def get_trades_multi( + self, + symbols: list[str], + start: str, + end: str, + limit: int = 1000, + feed: Literal["iex", "sip", "otc"] | None = None, + page_token: str | None = None, + asof: str | None = None, + ) -> dict[str, TradesResponse]: + """Retrieve historical trades for multiple symbols. 
+ + Args: + symbols: List of stock symbols (max 100) + start: Start time in RFC-3339 format + end: End time in RFC-3339 format + limit: Number of trades per symbol (1-10000, default 1000) + feed: Data feed to use + page_token: Token for pagination + asof: As-of time for historical data + + Returns: + Dictionary mapping symbols to TradesResponse objects + + Raises: + ValidationError: If parameters are invalid + APIRequestError: If the API request fails + """ + if not symbols: + raise ValidationError("At least one symbol is required") + + if len(symbols) > 100: + raise ValidationError("Maximum 100 symbols allowed") + + if limit < 1 or limit > 10000: + raise ValidationError("Limit must be between 1 and 10000") + + # Validate date formats (must include time) + try: + if "T" not in start or "T" not in end: + raise ValueError("Date must include time (RFC-3339 format)") + datetime.fromisoformat(start.replace("Z", "+00:00")) + datetime.fromisoformat(end.replace("Z", "+00:00")) + except (ValueError, AttributeError) as e: + raise ValidationError( + f"Invalid date format. 
Use RFC-3339 format (YYYY-MM-DDTHH:MM:SSZ): {e}" + ) from e + + # Build query parameters + params: dict[str, str | bool | float | int] = { + "symbols": ",".join(symbols), + "start": start, + "end": end, + "limit": limit, + } + + if feed: + params["feed"] = feed + if page_token: + params["page_token"] = page_token + if asof: + params["asof"] = asof + + # Make request + url = f"{self.base_url}/stocks/trades" + http_response = Requests().request( + "GET", url, headers=self.headers, params=params + ) + + if http_response.status_code != 200: + raise APIRequestError( + http_response.status_code, + f"Failed to retrieve trades: {http_response.text}", + ) + + response = json.loads(http_response.text) + + # Parse response for each symbol + result = {} + trades_data = response.get("trades", {}) + next_page_token = response.get("next_page_token") + + for symbol in symbols: + if symbol in trades_data: + trades = [ + trade_class_from_dict(trade, symbol) + for trade in trades_data[symbol] + ] + result[symbol] = TradesResponse( + trades=trades, + symbol=symbol, + next_page_token=next_page_token, + ) + else: + # Symbol had no trades in the time period + result[symbol] = TradesResponse( + trades=[], + symbol=symbol, + next_page_token=None, + ) + + return result + + def get_latest_trades_multi( + self, + symbols: list[str], + feed: Literal["iex", "sip", "otc"] | None = None, + asof: str | None = None, + ) -> dict[str, TradeModel]: + """Get latest trades for multiple symbols. 
+ + Args: + symbols: List of stock symbols (max 100) + feed: Data feed to use + asof: As-of time for historical data + + Returns: + Dictionary mapping symbols to their latest TradeModel + + Raises: + ValidationError: If parameters are invalid + APIRequestError: If the API request fails + """ + if not symbols: + raise ValidationError("At least one symbol is required") + + if len(symbols) > 100: + raise ValidationError("Maximum 100 symbols allowed") + + # Build query parameters + params: dict[str, str | bool | float | int] = {"symbols": ",".join(symbols)} + + if feed: + params["feed"] = feed + if asof: + params["asof"] = asof + + # Make request + url = f"{self.base_url}/stocks/trades/latest" + http_response = Requests().request( + "GET", url, headers=self.headers, params=params + ) + + if http_response.status_code != 200: + raise APIRequestError( + http_response.status_code, + f"Failed to retrieve latest trades: {http_response.text}", + ) + + response = json.loads(http_response.text) + + # Parse response + result = {} + trades_data = response.get("trades", response) + + for symbol in symbols: + if symbol in trades_data: + result[symbol] = trade_class_from_dict(trades_data[symbol], symbol) + + return result + + def get_all_trades( + self, + symbol: str, + start: str, + end: str, + feed: Literal["iex", "sip", "otc"] | None = None, + asof: str | None = None, + ) -> list[TradeModel]: + """Retrieve all trades for a symbol with automatic pagination. 
+ + Args: + symbol: The stock symbol + start: Start time in RFC-3339 format + end: End time in RFC-3339 format + feed: Data feed to use + asof: As-of time for historical data + + Returns: + List of all TradeModel objects across all pages + + Raises: + ValidationError: If parameters are invalid + APIRequestError: If the API request fails + """ + all_trades = [] + page_token = None + + while True: + response = self.get_trades( + symbol=symbol, + start=start, + end=end, + limit=10000, # Max limit for efficiency + feed=feed, + page_token=page_token, + asof=asof, + ) + + all_trades.extend(response.trades) + + # Check if there are more pages + if response.next_page_token: + page_token = response.next_page_token + else: + break + + return all_trades diff --git a/src/py_alpaca_api/trading/__init__.py b/src/py_alpaca_api/trading/__init__.py index 709aaeb..cecd9d8 100644 --- a/src/py_alpaca_api/trading/__init__.py +++ b/src/py_alpaca_api/trading/__init__.py @@ -1,4 +1,5 @@ from py_alpaca_api.trading.account import Account +from py_alpaca_api.trading.corporate_actions import CorporateActions from py_alpaca_api.trading.market import Market from py_alpaca_api.trading.news import News from py_alpaca_api.trading.orders import Orders @@ -23,6 +24,7 @@ def __init__(self, api_key: str, api_secret: str, api_paper: bool) -> None: def _initialize_components(self, headers: dict[str, str], base_url: str): self.account = Account(headers=headers, base_url=base_url) + self.corporate_actions = CorporateActions(headers=headers, base_url=base_url) self.market = Market(headers=headers, base_url=base_url) self.positions = Positions( headers=headers, base_url=base_url, account=self.account diff --git a/src/py_alpaca_api/trading/account.py b/src/py_alpaca_api/trading/account.py index b03033f..248a88c 100644 --- a/src/py_alpaca_api/trading/account.py +++ b/src/py_alpaca_api/trading/account.py @@ -8,6 +8,10 @@ AccountActivityModel, account_activity_class_from_dict, ) +from 
############################################
# Get Account Configuration
############################################
def get_configuration(self) -> AccountConfigModel:
    """Fetch the account's current configuration settings.

    Returns:
        AccountConfigModel: The current account configuration.

    Raises:
        APIRequestError: If the request to retrieve configuration fails.
    """
    endpoint = f"{self.base_url}/account/configurations"
    reply = Requests().request("GET", endpoint, headers=self.headers)

    # Anything other than HTTP 200 is surfaced as an API error.
    if reply.status_code != 200:
        raise APIRequestError(
            reply.status_code,
            f"Failed to retrieve account configuration: {reply.status_code}",
        )

    return account_config_class_from_dict(json.loads(reply.text))
############################################
# Update Account Configuration
############################################
def update_configuration(
    self,
    dtbp_check: str | None = None,
    fractional_trading: bool | None = None,
    max_margin_multiplier: str | None = None,
    no_shorting: bool | None = None,
    pdt_check: str | None = None,
    ptp_no_exception_entry: bool | None = None,
    suspend_trade: bool | None = None,
    trade_confirm_email: str | None = None,
) -> AccountConfigModel:
    """Updates the account configuration settings.

    Only parameters explicitly provided (non-None) are sent to the API.

    Args:
        dtbp_check: Day trade buying power check ("entry", "exit", "both")
        fractional_trading: Whether to enable fractional trading
        max_margin_multiplier: Maximum margin multiplier ("1", "2", "4")
        no_shorting: Whether to disable short selling
        pdt_check: Pattern day trader check ("entry", "exit", "both")
        ptp_no_exception_entry: Whether to enable PTP no exception entry
        suspend_trade: Whether to suspend trading
        trade_confirm_email: Trade confirmation emails ("all", "none")

    Returns:
        AccountConfigModel: The updated account configuration.

    Raises:
        APIRequestError: If the request to update configuration fails.
        ValueError: If invalid parameter values are provided, or if no
            parameter is provided at all.
    """
    # Validate string-valued parameters against their allowed values.
    # Use `is not None` rather than truthiness so an empty string is
    # rejected here instead of being silently sent to the API.
    validations = {
        "dtbp_check": (dtbp_check, ["entry", "exit", "both"]),
        "pdt_check": (pdt_check, ["entry", "exit", "both"]),
        "max_margin_multiplier": (max_margin_multiplier, ["1", "2", "4"]),
        "trade_confirm_email": (trade_confirm_email, ["all", "none"]),
    }

    for param_name, (value, valid_values) in validations.items():
        if value is not None and value not in valid_values:
            raise ValueError(
                f"{param_name} must be one of: {', '.join(valid_values)}"
            )

    # Build the request body from explicitly-provided parameters only.
    provided = {
        "dtbp_check": dtbp_check,
        "fractional_trading": fractional_trading,
        "max_margin_multiplier": max_margin_multiplier,
        "no_shorting": no_shorting,
        "pdt_check": pdt_check,
        "ptp_no_exception_entry": ptp_no_exception_entry,
        "suspend_trade": suspend_trade,
        "trade_confirm_email": trade_confirm_email,
    }
    body: dict[str, str | bool] = {
        key: value for key, value in provided.items() if value is not None
    }

    if not body:
        raise ValueError("At least one configuration parameter must be provided")

    url = f"{self.base_url}/account/configurations"
    http_response = Requests().request(
        "PATCH", url, headers=self.headers, json=body
    )

    if http_response.status_code != 200:
        raise APIRequestError(
            http_response.status_code,
            f"Failed to update account configuration: {http_response.status_code}",
        )

    response = json.loads(http_response.text)
    return account_config_class_from_dict(response)
class CorporateActions:
    """Client for Alpaca's corporate-action announcement endpoints.

    Validates parameters client-side before issuing requests against
    ``/corporate_actions/announcements``.
    """

    def __init__(self, headers: dict[str, str], base_url: str) -> None:
        # Auth headers and trading-API base URL supplied by the Trading client.
        self.headers = headers
        self.base_url = base_url

    def get_announcements(
        self,
        since: str,
        until: str,
        ca_types: list[str],
        symbol: str | None = None,
        cusip: str | None = None,
        date_type: Literal["declaration_date", "ex_date", "record_date", "payable_date"]
        | None = None,
        page_limit: int = 100,
        page_token: str | None = None,
    ) -> list[CorporateActionModel]:
        """Retrieve corporate action announcements.

        Args:
            since: The start (inclusive) of the date range in YYYY-MM-DD format.
                Date range is limited to 90 days.
            until: The end (inclusive) of the date range in YYYY-MM-DD format.
                Date range is limited to 90 days.
            ca_types: List of corporate action types to return.
                Valid types: dividend, merger, spinoff, split
            symbol: Optional filter by symbol
            cusip: Optional filter by CUSIP
            date_type: Optional date type for filtering (declaration_date,
                ex_date, record_date, payable_date)
            page_limit: Number of results per page (Note: API may return all
                results regardless)
            page_token: Token for pagination (currently not used by API)

        Returns:
            List of CorporateActionModel objects

        Raises:
            ValidationError: If date range exceeds 90 days or invalid parameters
            APIRequestError: If the API request fails

        """
        # Parse the dates first and keep the try narrow: if the range
        # checks lived inside it, a ValidationError that subclasses
        # ValueError would be caught below and mislabelled as a format
        # error.
        try:
            since_date = datetime.strptime(since, "%Y-%m-%d")
            until_date = datetime.strptime(until, "%Y-%m-%d")
        except ValueError as e:
            raise ValidationError(f"Invalid date format. Use YYYY-MM-DD: {e}") from e

        date_diff = (until_date - since_date).days
        if date_diff > 90:
            raise ValidationError("Date range cannot exceed 90 days")
        if date_diff < 0:
            raise ValidationError("'since' date must be before 'until' date")

        # Fail fast on an empty filter instead of sending ca_types="".
        if not ca_types:
            raise ValidationError("At least one corporate action type is required")
        valid_types = {"dividend", "merger", "spinoff", "split"}
        for ca_type in ca_types:
            if ca_type not in valid_types:
                raise ValidationError(
                    f"Invalid corporate action type: {ca_type}. "
                    f"Valid types are: {', '.join(valid_types)}"
                )

        # Build query parameters; optional filters only when provided.
        params: dict[str, str | bool | float | int] = {
            "since": since,
            "until": until,
            "ca_types": ",".join(ca_types),
            "page_limit": min(page_limit, 500),  # defensive clamp on page size
        }

        if symbol:
            params["symbol"] = symbol
        if cusip:
            params["cusip"] = cusip
        if date_type:
            params["date_type"] = date_type
        if page_token:
            params["page_token"] = page_token

        url = f"{self.base_url}/corporate_actions/announcements"
        http_response = Requests().request(
            "GET", url, headers=self.headers, params=params
        )

        if http_response.status_code != 200:
            raise APIRequestError(
                http_response.status_code,
                f"Failed to retrieve corporate actions: {http_response.text}",
            )

        response = json.loads(http_response.text)

        # The endpoint may return a bare list or an object wrapping it.
        if isinstance(response, list):
            announcements = response
        else:
            announcements = response.get("announcements", [])

        # NOTE(review): a next_page_token, if the API ever returns one,
        # is not followed here; only the current page is returned.
        return [
            corporate_action_class_from_dict(announcement)
            for announcement in announcements
        ]

    def get_announcement_by_id(self, announcement_id: str) -> CorporateActionModel:
        """Retrieve a specific corporate action announcement by ID.

        Args:
            announcement_id: The unique ID of the announcement

        Returns:
            CorporateActionModel object

        Raises:
            APIRequestError: If the API request fails or announcement not found
        """
        url = f"{self.base_url}/corporate_actions/announcements/{announcement_id}"
        http_response = Requests().request("GET", url, headers=self.headers)

        # Map 404 to a clearer message than the generic failure below.
        if http_response.status_code == 404:
            raise APIRequestError(
                404,
                f"Corporate action announcement not found: {announcement_id}",
            )
        if http_response.status_code != 200:
            raise APIRequestError(
                http_response.status_code,
                f"Failed to retrieve corporate action: {http_response.text}",
            )

        response = json.loads(http_response.text)
        return corporate_action_class_from_dict(response)

    def get_all_announcements(
        self,
        since: str,
        until: str,
        ca_types: list[str],
        symbol: str | None = None,
        cusip: str | None = None,
        date_type: Literal["declaration_date", "ex_date", "record_date", "payable_date"]
        | None = None,
    ) -> list[CorporateActionModel]:
        """Retrieve all corporate action announcements.

        Note: The API currently returns all results within the date range
        without pagination, so this method simply calls get_announcements.

        Args:
            since: The start (inclusive) of the date range in YYYY-MM-DD format.
            until: The end (inclusive) of the date range in YYYY-MM-DD format.
            ca_types: List of corporate action types to return.
            symbol: Optional filter by symbol
            cusip: Optional filter by CUSIP
            date_type: Optional date type for filtering

        Returns:
            List of all CorporateActionModel objects

        Raises:
            ValidationError: If date range exceeds 90 days or invalid parameters
            APIRequestError: If the API request fails
        """
        # API returns all results within date range, no pagination needed currently
        return self.get_announcements(
            since=since,
            until=until,
            ca_types=ca_types,
            symbol=symbol,
            cusip=cusip,
            date_type=date_type,
        )
########################################################
# \\\\\\\\\ Replace Order /////////////////////#
########################################################
def replace_order(
    self,
    order_id: str,
    qty: float | None = None,
    limit_price: float | None = None,
    stop_price: float | None = None,
    trail: float | None = None,
    time_in_force: str | None = None,
    client_order_id: str | None = None,
) -> OrderModel:
    """Replace an existing order with updated parameters.

    Args:
        order_id: The ID of the order to replace.
        qty: The new quantity for the order.
        limit_price: The new limit price for limit orders.
        stop_price: The new stop price for stop orders.
        trail: The new trail amount for trailing stop orders (percent or price).
        time_in_force: The new time in force for the order.
        client_order_id: Optional client-assigned ID for the replacement order.
            Providing only this value does not count as an update.

    Returns:
        OrderModel: The replaced order.

    Raises:
        ValidationError: If no parameters are provided to update.
        APIRequestError: If the API request fails.
    """
    # Use explicit None checks (not truthiness) so a zero value still
    # counts as "provided", consistent with the body-building below.
    # client_order_id alone is deliberately not enough: it renames the
    # replacement but changes nothing about the order itself.
    updates = (qty, limit_price, stop_price, trail, time_in_force)
    if all(value is None for value in updates):
        raise ValidationError(
            "At least one parameter must be provided to replace the order"
        )

    body: dict[str, str | float | None] = {}
    if qty is not None:
        body["qty"] = qty
    if limit_price is not None:
        body["limit_price"] = limit_price
    if stop_price is not None:
        body["stop_price"] = stop_price
    if trail is not None:
        body["trail"] = trail
    if time_in_force is not None:
        body["time_in_force"] = time_in_force
    if client_order_id is not None:
        body["client_order_id"] = client_order_id

    url = f"{self.base_url}/orders/{order_id}"

    response = json.loads(
        Requests()
        .request(method="PATCH", url=url, headers=self.headers, json=body)
        .text
    )
    return order_class_from_dict(response)
########################################################
# \\\\\\\ Get Order By Client ID ////////////////#
########################################################
def get_by_client_order_id(self, client_order_id: str) -> OrderModel:
    """Retrieves order information by client order ID.

    Note: This queries all orders and filters by client_order_id.
    The Alpaca API doesn't have a direct endpoint for this.

    Args:
        client_order_id: The client-assigned ID of the order to retrieve.

    Returns:
        OrderModel: An object representing the order information.

    Raises:
        APIRequestError: If the request fails or order not found.
        ValidationError: If no order with given client_order_id is found.
    """
    # Pull the most recent orders in every status and scan client-side.
    query: dict[str, str | bool | float | int] = {"status": "all", "limit": 500}
    url = f"{self.base_url}/orders"

    raw = (
        Requests()
        .request(method="GET", url=url, headers=self.headers, params=query)
        .text
    )
    orders = json.loads(raw)

    # First order whose client_order_id matches, or None.
    match = next(
        (entry for entry in orders if entry.get("client_order_id") == client_order_id),
        None,
    )
    if match is not None:
        return order_class_from_dict(match)

    raise ValidationError(f"No order found with client_order_id: {client_order_id}")
########################################################
# \\\\\\ Cancel Order By Client ID ///////////////#
########################################################
def cancel_by_client_order_id(self, client_order_id: str) -> str:
    """Cancel an order by its client order ID.

    Note: This first retrieves the order by client_order_id, then cancels by ID.

    Args:
        client_order_id: The client-assigned ID of the order to be cancelled.

    Returns:
        str: A message indicating the status of the cancellation.

    Raises:
        APIRequestError: If the cancellation request fails.
        ValidationError: If no order with given client_order_id is found.
    """
    # Resolve the client-assigned ID to Alpaca's order ID, then reuse
    # the standard cancel-by-id path.
    matching_order = self.get_by_client_order_id(client_order_id)
    return self.cancel_by_id(matching_order.id)
@@ -226,6 +357,8 @@ def limit( or "gtc" (good till canceled). Default is "day". extended_hours (bool, optional): Whether to allow trading during extended hours. Default is False. + client_order_id (str, optional): Client-assigned ID for the order. Defaults to None. + order_class (str, optional): Order class (simple/bracket/oco/oto). Defaults to None. Returns: OrderModel: The submitted order. @@ -253,6 +386,8 @@ def limit( entry_type="limit", time_in_force=time_in_force, extended_hours=extended_hours, + client_order_id=client_order_id, + order_class=order_class, ) ######################################################## @@ -268,6 +403,8 @@ def stop( stop_loss: float | None = None, time_in_force: str = "day", extended_hours: bool = False, + client_order_id: str | None = None, + order_class: str | None = None, ) -> OrderModel: """Args: @@ -283,6 +420,8 @@ def stop( Defaults to 'day'. extended_hours: A boolean value indicating whether to place the order during extended hours. Defaults to False. + client_order_id: Client-assigned ID for the order. Defaults to None. + order_class: Order class (simple/bracket/oco/oto). Defaults to None. Returns: An instance of the OrderModel representing the submitted order. @@ -311,6 +450,8 @@ def stop( entry_type="stop", time_in_force=time_in_force, extended_hours=extended_hours, + client_order_id=client_order_id, + order_class=order_class, ) ######################################################## @@ -325,6 +466,8 @@ def stop_limit( side: str = "buy", time_in_force: str = "day", extended_hours: bool = False, + client_order_id: str | None = None, + order_class: str | None = None, ) -> OrderModel: """Submits a stop-limit order for trading. @@ -339,6 +482,8 @@ def stop_limit( Defaults to 'day'. extended_hours (bool, optional): Whether to allow trading during extended hours. Defaults to False. + client_order_id (str, optional): Client-assigned ID for the order. Defaults to None. 
+ order_class (str, optional): Order class (simple/bracket/oco/oto). Defaults to None. Returns: OrderModel: The submitted stop-limit order. @@ -366,6 +511,8 @@ def stop_limit( entry_type="stop_limit", time_in_force=time_in_force, extended_hours=extended_hours, + client_order_id=client_order_id, + order_class=order_class, ) ######################################################## @@ -380,6 +527,8 @@ def trailing_stop( side: str = "buy", time_in_force: str = "day", extended_hours: bool = False, + client_order_id: str | None = None, + order_class: str | None = None, ) -> OrderModel: """Submits a trailing stop order for the specified symbol. @@ -392,7 +541,10 @@ def trailing_stop( `trail_percent` or `trail_price` must be provided, not both. Defaults to None. side (str, optional): The side of the order, either 'buy' or 'sell'. Defaults to 'buy'. time_in_force (str, optional): The time in force for the order. Defaults to 'day'. - extended_hours (bool, optional): Whether to allow trading during extended hours.\n Defaults to False. + extended_hours (bool, optional): Whether to allow trading during extended hours. + Defaults to False. + client_order_id (str, optional): Client-assigned ID for the order. Defaults to None. + order_class (str, optional): Order class (simple/bracket/oco/oto). Defaults to None. Returns: OrderModel: The submitted trailing stop order. @@ -426,6 +578,8 @@ def trailing_stop( entry_type="trailing_stop", time_in_force=time_in_force, extended_hours=extended_hours, + client_order_id=client_order_id, + order_class=order_class, ) ######################################################## @@ -446,6 +600,8 @@ def _submit_order( side: str = "buy", time_in_force: str = "day", extended_hours: bool = False, + client_order_id: str | None = None, + order_class: str | None = None, ) -> OrderModel: """Submits an order to the Alpaca API. @@ -470,7 +626,10 @@ def _submit_order( side (str, optional): The side of the trade (buy or sell). Defaults to "buy". 
time_in_force (str, optional): The time in force for the order. Defaults to "day". - extended_hours (bool, optional): Whether to allow trading during extended hours.\n Defaults to False. + extended_hours (bool, optional): Whether to allow trading during extended hours. + Defaults to False. + client_order_id (str, optional): Client-assigned ID for the order. Defaults to None. + order_class (str, optional): Order class (simple/bracket/oco/oto). Defaults to None. Returns: OrderModel: The submitted order. @@ -478,6 +637,17 @@ def _submit_order( Raises: Exception: If the order submission fails. """ + # Determine order class + if order_class: + # Use explicitly provided order class + final_order_class = order_class + elif take_profit or stop_loss: + # Bracket order if take profit or stop loss is specified + final_order_class = "bracket" + else: + # Default to simple + final_order_class = "simple" + payload = { "symbol": symbol, "qty": qty if qty else None, @@ -486,13 +656,14 @@ def _submit_order( "limit_price": limit_price if limit_price else None, "trail_percent": trail_percent if trail_percent else None, "trail_price": trail_price if trail_price else None, - "order_class": "bracket" if take_profit or stop_loss else "simple", + "order_class": final_order_class, "take_profit": take_profit, "stop_loss": stop_loss, "side": side if side == "buy" else "sell", "type": entry_type, "time_in_force": time_in_force, "extended_hours": extended_hours, + "client_order_id": client_order_id if client_order_id else None, } url = f"{self.base_url}/orders" diff --git a/tests/test_cache/test_cache_integration.py b/tests/test_cache/test_cache_integration.py new file mode 100644 index 0000000..76b8adb --- /dev/null +++ b/tests/test_cache/test_cache_integration.py @@ -0,0 +1,301 @@ +"""Integration tests for cache system.""" + +from __future__ import annotations + +import os +import time +from unittest.mock import patch + +import pytest + +from py_alpaca_api import PyAlpacaAPI +from 
@pytest.fixture
def cache_manager():
    """In-memory CacheManager sized small enough for fast tests."""
    config = CacheConfig(
        cache_type=CacheType.MEMORY,
        max_size=100,
        default_ttl=60,
    )
    return CacheManager(config)


@pytest.fixture
def alpaca():
    """Paper-trading PyAlpacaAPI client built from env credentials."""
    return PyAlpacaAPI(
        api_key=os.getenv("ALPACA_API_KEY", "test_key"),
        api_secret=os.getenv("ALPACA_SECRET_KEY", "test_secret"),
        api_paper=True,
    )


class TestCacheIntegration:
    """Integration tests for cache system."""

    def test_cache_with_bars_data(self, cache_manager):
        """Bar data round-trips through the cache and registers a hit."""
        sample_bars = {
            "symbol": "AAPL",
            "bars": [
                {
                    "t": "2024-01-01T00:00:00Z",
                    "o": 100,
                    "h": 105,
                    "l": 99,
                    "c": 103,
                    "v": 1000,
                },
                {
                    "t": "2024-01-02T00:00:00Z",
                    "o": 103,
                    "h": 108,
                    "l": 102,
                    "c": 106,
                    "v": 1200,
                },
            ],
        }

        key = cache_manager.generate_key(
            "bars",
            symbol="AAPL",
            start="2024-01-01",
            end="2024-01-02",
            timeframe="1d",
        )
        cache_manager.set(key, sample_bars, "bars")

        assert cache_manager.get(key, "bars") == sample_bars
        assert cache_manager._hit_count == 1

    def test_cache_with_quotes_data(self, cache_manager):
        """A 1-second-TTL quote is served until it expires, then dropped."""
        quote = {
            "symbol": "AAPL",
            "bid": 150.25,
            "ask": 150.30,
            "bid_size": 100,
            "ask_size": 200,
            "timestamp": "2024-01-01T10:30:00Z",
        }

        key = cache_manager.generate_key("quotes", symbol="AAPL")
        cache_manager.set(key, quote, "quotes", ttl=1)  # 1 second TTL

        # Immediate read hits the fresh entry.
        assert cache_manager.get(key) == quote

        # After the TTL lapses the entry is gone.
        time.sleep(1.1)
        assert cache_manager.get(key) is None
def test_cache_with_market_hours(self, cache_manager):
    """Market-hours entries round-trip and use the 1-day TTL bucket."""
    market_hours = {
        "date": "2024-01-01",
        "open": "09:30",
        "close": "16:00",
        "is_open": True,
    }

    key = cache_manager.generate_key("market_hours", date="2024-01-01")
    cache_manager.set(key, market_hours, "market_hours")

    assert cache_manager.get(key) == market_hours

    # market_hours entries are configured with a 1-day (86400s) TTL.
    assert cache_manager.config.get_ttl("market_hours") == 86400


def test_cache_invalidation_on_symbol(self, cache_manager):
    """Pattern invalidation removes every entry for one symbol only."""
    cache_manager._cache.set("bars:AAPL:1d", {"data": "aapl_daily"}, 60)
    cache_manager._cache.set("bars:AAPL:1h", {"data": "aapl_hourly"}, 60)
    cache_manager._cache.set("quotes:AAPL", {"data": "aapl_quote"}, 60)
    cache_manager._cache.set("bars:GOOGL:1d", {"data": "googl_daily"}, 60)

    removed = cache_manager.invalidate_pattern("*AAPL*")

    assert removed == 3
    # Unrelated symbol survives the sweep.
    assert cache_manager._cache.get("bars:GOOGL:1d") == {"data": "googl_daily"}


def test_cache_size_limit(self, cache_manager):
    """Only the newest max_size entries survive once the cap is hit."""
    cache_manager.config.max_size = 5
    cache_manager._cache.max_size = 5

    # Insert twice as many items as the cache can hold.
    for item_id in range(10):
        key = cache_manager.generate_key("test", id=item_id)
        cache_manager.set(key, f"value_{item_id}", "test")

    assert cache_manager._cache.size() == 5

    # Newest five are retained ...
    for item_id in range(5, 10):
        assert cache_manager.get(cache_manager.generate_key("test", id=item_id)) is not None

    # ... oldest five were evicted.
    for item_id in range(5):
        assert cache_manager.get(cache_manager.generate_key("test", id=item_id)) is None
@cache_manager.cached("assets", ttl=3600) + def get_asset(symbol: str) -> dict: + nonlocal api_call_count + api_call_count += 1 + # Simulate API call + return {"symbol": symbol, "name": f"{symbol} Company", "exchange": "NASDAQ"} + + # First call should make API call + result1 = get_asset("AAPL") + assert result1["symbol"] == "AAPL" + assert api_call_count == 1 + + # Second call should use cache + result2 = get_asset("AAPL") + assert result2 == result1 + assert api_call_count == 1 # No additional API call + + # Different symbol should make API call + result3 = get_asset("GOOGL") + assert result3["symbol"] == "GOOGL" + assert api_call_count == 2 + + def test_concurrent_cache_access(self, cache_manager): + """Test concurrent access to cache.""" + import threading + + results = [] + errors = [] + + def cache_operation(thread_id: int): + try: + # Each thread sets and gets its own key + key = cache_manager.generate_key("thread", id=thread_id) + cache_manager.set(key, f"value_{thread_id}", "test") + time.sleep(0.01) # Small delay + value = cache_manager.get(key) + results.append(value == f"value_{thread_id}") + except Exception as e: + errors.append(e) + + # Create multiple threads + threads = [] + for i in range(10): + thread = threading.Thread(target=cache_operation, args=(i,)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # All operations should succeed + assert len(errors) == 0 + assert all(results) + + def test_cache_stats_accuracy(self, cache_manager): + """Test accuracy of cache statistics.""" + # Perform various operations + cache_manager.set("key1", "value1", "test") + cache_manager.set("key2", "value2", "test") + + _ = cache_manager.get("key1") # Hit + _ = cache_manager.get("key2") # Hit + _ = cache_manager.get("key3") # Miss + _ = cache_manager.get("key4") # Miss + + cache_manager.delete("key1") + + stats = cache_manager.get_stats() + + assert stats["size"] == 1 # Only key2 remains 
+ assert stats["hit_count"] == 2 + assert stats["miss_count"] == 2 + assert stats["hit_rate"] == 0.5 + assert stats["total_requests"] == 4 + + def test_cache_clear_by_data_type(self, cache_manager): + """Test clearing cache by data type prefix.""" + # Add items of different types + bars_key = cache_manager.generate_key("bars", symbol="AAPL") + quotes_key = cache_manager.generate_key("quotes", symbol="AAPL") + trades_key = cache_manager.generate_key("trades", symbol="AAPL") + + cache_manager.set(bars_key, "bars_data", "bars") + cache_manager.set(quotes_key, "quotes_data", "quotes") + cache_manager.set(trades_key, "trades_data", "trades") + + # Clear only bars + count = cache_manager.clear("bars") + + assert count == 1 + assert cache_manager.get(bars_key) is None + assert cache_manager.get(quotes_key) == "quotes_data" + assert cache_manager.get(trades_key) == "trades_data" + + def test_cache_memory_efficiency(self, cache_manager): + """Test memory efficiency with large datasets.""" + # Create a large dataset + large_data = { + "symbol": "AAPL", + "bars": [ + {"t": f"2024-01-{i:02d}", "o": 100 + i, "c": 100 + i + 1} + for i in range(1, 32) # 31 days of data + ], + } + + key = cache_manager.generate_key("bars", symbol="AAPL", month="2024-01") + cache_manager.set(key, large_data, "bars") + + # Should be able to retrieve + cached = cache_manager.get(key) + assert cached == large_data + assert len(cached["bars"]) == 31 + + def test_redis_cache_simulation(self): + """Test Redis cache configuration (simulated).""" + from py_alpaca_api.cache.cache_manager import LRUCache, RedisCache + + config = CacheConfig( + cache_type=CacheType.REDIS, + redis_host="localhost", + redis_port=6379, + redis_password="test_password", + ) + + # Mock the RedisCache to simulate unavailable Redis server + with patch.object( + RedisCache, "_get_client", side_effect=Exception("Redis not available") + ): + # Should fall back to memory cache gracefully + manager = CacheManager(config) + assert 
isinstance(manager._cache, LRUCache) + + # Should still work with memory cache + manager.set("key1", "value1", "test") + assert manager.get("key1") == "value1" diff --git a/tests/test_cache/test_cache_manager.py b/tests/test_cache/test_cache_manager.py new file mode 100644 index 0000000..2b345de --- /dev/null +++ b/tests/test_cache/test_cache_manager.py @@ -0,0 +1,381 @@ +"""Tests for cache manager.""" + +from __future__ import annotations + +import time +from dataclasses import dataclass +from unittest.mock import patch + +from py_alpaca_api.cache import CacheConfig, CacheManager, CacheType +from py_alpaca_api.cache.cache_manager import LRUCache, RedisCache + + +class TestLRUCache: + """Test LRU cache implementation.""" + + def test_init(self): + """Test LRU cache initialization.""" + cache = LRUCache(max_size=100) + assert cache.max_size == 100 + assert cache.size() == 0 + + def test_set_and_get(self): + """Test setting and getting items.""" + cache = LRUCache() + cache.set("key1", "value1", ttl=60) + + assert cache.get("key1") == "value1" + assert cache.size() == 1 + + def test_expiry(self): + """Test item expiry.""" + cache = LRUCache() + cache.set("key1", "value1", ttl=0) # Already expired + + time.sleep(0.01) # Small delay to ensure expiry + assert cache.get("key1") is None + assert cache.size() == 0 + + def test_lru_eviction(self): + """Test LRU eviction when cache is full.""" + cache = LRUCache(max_size=3) + + cache.set("key1", "value1", ttl=60) + cache.set("key2", "value2", ttl=60) + cache.set("key3", "value3", ttl=60) + + assert cache.size() == 3 + + # Add another item, should evict key1 + cache.set("key4", "value4", ttl=60) + + assert cache.size() == 3 + assert cache.get("key1") is None + assert cache.get("key2") == "value2" + assert cache.get("key3") == "value3" + assert cache.get("key4") == "value4" + + def test_lru_order(self): + """Test LRU ordering on access.""" + cache = LRUCache(max_size=3) + + cache.set("key1", "value1", ttl=60) + 
cache.set("key2", "value2", ttl=60) + cache.set("key3", "value3", ttl=60) + + # Access key1 to make it recently used + _ = cache.get("key1") + + # Add key4, should evict key2 (least recently used) + cache.set("key4", "value4", ttl=60) + + assert cache.get("key1") == "value1" + assert cache.get("key2") is None # Evicted + assert cache.get("key3") == "value3" + assert cache.get("key4") == "value4" + + def test_delete(self): + """Test deleting items.""" + cache = LRUCache() + cache.set("key1", "value1", ttl=60) + + assert cache.delete("key1") is True + assert cache.get("key1") is None + assert cache.delete("key1") is False # Already deleted + + def test_clear(self): + """Test clearing cache.""" + cache = LRUCache() + cache.set("key1", "value1", ttl=60) + cache.set("key2", "value2", ttl=60) + + cache.clear() + assert cache.size() == 0 + assert cache.get("key1") is None + assert cache.get("key2") is None + + def test_cleanup_expired(self): + """Test cleaning up expired items.""" + cache = LRUCache() + cache.set("key1", "value1", ttl=0) # Already expired + cache.set("key2", "value2", ttl=60) # Not expired + + time.sleep(0.01) # Small delay to ensure expiry + removed = cache.cleanup_expired() + + assert removed == 1 + assert cache.size() == 1 + assert cache.get("key2") == "value2" + + +class TestCacheConfig: + """Test cache configuration.""" + + def test_default_config(self): + """Test default configuration.""" + config = CacheConfig() + + assert config.cache_type == CacheType.MEMORY + assert config.max_size == 1000 + assert config.default_ttl == 300 + assert config.enabled is True + + def test_custom_config(self): + """Test custom configuration.""" + config = CacheConfig( + cache_type=CacheType.REDIS, + max_size=500, + default_ttl=600, + enabled=False, + ) + + assert config.cache_type == CacheType.REDIS + assert config.max_size == 500 + assert config.default_ttl == 600 + assert config.enabled is False + + def test_get_ttl(self): + """Test getting TTL for data types.""" + 
config = CacheConfig() + + assert config.get_ttl("market_hours") == 86400 + assert config.get_ttl("positions") == 10 + assert config.get_ttl("unknown") == 300 # Default TTL + + def test_custom_data_ttls(self): + """Test custom data TTLs.""" + config = CacheConfig(data_ttls={"custom_type": 120}) + + assert config.get_ttl("custom_type") == 120 + assert config.get_ttl("unknown") == 300 # Default TTL + + +class TestCacheManager: + """Test cache manager.""" + + def test_init_default(self): + """Test default initialization.""" + manager = CacheManager() + + assert manager.config.cache_type == CacheType.MEMORY + assert manager.config.enabled is True + assert isinstance(manager._cache, LRUCache) + + def test_init_disabled(self): + """Test disabled cache.""" + config = CacheConfig(enabled=False) + manager = CacheManager(config) + + # Should still work but not store anything + manager.set("key1", "value1", "test") + assert manager.get("key1") is None + + def test_generate_key(self): + """Test cache key generation.""" + manager = CacheManager() + + key1 = manager.generate_key("bars", symbol="AAPL", timeframe="1d") + key2 = manager.generate_key("bars", symbol="AAPL", timeframe="1d") + key3 = manager.generate_key("bars", symbol="GOOGL", timeframe="1d") + + assert key1 == key2 # Same parameters + assert key1 != key3 # Different parameters + + def test_generate_key_long(self): + """Test cache key generation with long parameters.""" + manager = CacheManager() + + long_value = "x" * 200 + key = manager.generate_key("test", value=long_value) + + # Should use hash for long keys + assert len(key) < 100 + assert ":" in key + + def test_get_and_set(self): + """Test getting and setting cache items.""" + manager = CacheManager() + + manager.set("key1", {"data": "value"}, "test", ttl=60) + value = manager.get("key1", "test") + + assert value == {"data": "value"} + assert manager._hit_count == 1 + assert manager._miss_count == 0 + + def test_cache_miss(self): + """Test cache miss.""" + 
manager = CacheManager() + + value = manager.get("nonexistent", "test") + + assert value is None + assert manager._hit_count == 0 + assert manager._miss_count == 1 + + def test_dataclass_serialization(self): + """Test caching dataclass objects.""" + + @dataclass + class TestModel: + id: int + name: str + + manager = CacheManager() + model = TestModel(id=1, name="test") + + manager.set("key1", model, "test") + value = manager.get("key1") + + assert value == {"id": 1, "name": "test"} + + def test_list_of_dataclasses(self): + """Test caching list of dataclass objects.""" + + @dataclass + class TestModel: + id: int + + manager = CacheManager() + models = [TestModel(id=1), TestModel(id=2)] + + manager.set("key1", models, "test") + value = manager.get("key1") + + assert value == [{"id": 1}, {"id": 2}] + + def test_delete(self): + """Test deleting cache items.""" + manager = CacheManager() + + manager.set("key1", "value1", "test") + assert manager.delete("key1") is True + assert manager.get("key1") is None + assert manager.delete("key1") is False + + def test_clear_all(self): + """Test clearing entire cache.""" + manager = CacheManager() + + manager.set("key1", "value1", "test") + manager.set("key2", "value2", "test") + + count = manager.clear() + + assert count == 2 + assert manager.get("key1") is None + assert manager.get("key2") is None + + def test_clear_prefix(self): + """Test clearing cache by prefix.""" + manager = CacheManager() + + # Generate proper keys + bars_key1 = manager.generate_key("bars", key="key1") + bars_key2 = manager.generate_key("bars", key="key2") + quotes_key1 = manager.generate_key("quotes", key="key1") + + manager.set(bars_key1, "value1", "bars") + manager.set(bars_key2, "value2", "bars") + manager.set(quotes_key1, "value3", "quotes") + + count = manager.clear("bars") + + assert count == 2 + assert manager.get(bars_key1) is None + assert manager.get(bars_key2) is None + assert manager.get(quotes_key1) == "value3" + + def 
test_invalidate_pattern(self): + """Test invalidating by pattern.""" + manager = CacheManager() + + manager._cache.set("bars:AAPL:1d", "value1", 60) + manager._cache.set("bars:AAPL:1h", "value2", 60) + manager._cache.set("bars:GOOGL:1d", "value3", 60) + + count = manager.invalidate_pattern("bars:AAPL*") + + assert count == 2 + assert manager._cache.get("bars:AAPL:1d") is None + assert manager._cache.get("bars:AAPL:1h") is None + assert manager._cache.get("bars:GOOGL:1d") == "value3" + + def test_get_stats(self): + """Test getting cache statistics.""" + manager = CacheManager() + + manager.set("key1", "value1", "test") + _ = manager.get("key1") # Hit + _ = manager.get("key2") # Miss + + stats = manager.get_stats() + + assert stats["enabled"] is True + assert stats["type"] == "memory" + assert stats["size"] == 1 + assert stats["hit_count"] == 1 + assert stats["miss_count"] == 1 + assert stats["hit_rate"] == 0.5 + assert stats["total_requests"] == 2 + + def test_reset_stats(self): + """Test resetting statistics.""" + manager = CacheManager() + + _ = manager.get("key1") # Miss + manager.reset_stats() + + stats = manager.get_stats() + + assert stats["hit_count"] == 0 + assert stats["miss_count"] == 0 + + def test_cached_decorator(self): + """Test cached decorator.""" + manager = CacheManager() + + call_count = 0 + + @manager.cached("test", ttl=60) + def expensive_function(x: int) -> int: + nonlocal call_count + call_count += 1 + return x * 2 + + # First call should execute function + result1 = expensive_function(5) + assert result1 == 10 + assert call_count == 1 + + # Second call should use cache + result2 = expensive_function(5) + assert result2 == 10 + assert call_count == 1 # Not incremented + + # Different argument should execute function + result3 = expensive_function(10) + assert result3 == 20 + assert call_count == 2 + + def test_redis_fallback(self): + """Test fallback to memory cache when Redis unavailable.""" + config = CacheConfig(cache_type=CacheType.REDIS) 
+ + # Mock the RedisCache._get_client to simulate Redis unavailable + with patch.object( + RedisCache, "_get_client", side_effect=Exception("Connection failed") + ): + manager = CacheManager(config) + + # Should fall back to memory cache + assert isinstance(manager._cache, LRUCache) + + def test_disabled_cache(self): + """Test disabled cache type.""" + config = CacheConfig(cache_type=CacheType.DISABLED) + manager = CacheManager(config) + + manager.set("key1", "value1", "test") + assert manager.get("key1") is None + assert manager._cache.size() == 0 diff --git a/tests/test_http/test_feed_manager.py b/tests/test_http/test_feed_manager.py new file mode 100644 index 0000000..f6e6e5f --- /dev/null +++ b/tests/test_http/test_feed_manager.py @@ -0,0 +1,397 @@ +from __future__ import annotations + +from unittest.mock import Mock + +import pytest + +from py_alpaca_api.exceptions import APIRequestError, ValidationError +from py_alpaca_api.http.feed_manager import ( + FeedConfig, + FeedManager, + FeedType, + SubscriptionLevel, +) + + +class TestFeedType: + """Test FeedType enum functionality.""" + + def test_feed_type_values(self): + """Test feed type enum values.""" + assert FeedType.SIP.value == "sip" + assert FeedType.IEX.value == "iex" + assert FeedType.OTC.value == "otc" + + def test_from_string_valid(self): + """Test creating FeedType from valid string.""" + assert FeedType.from_string("sip") == FeedType.SIP + assert FeedType.from_string("SIP") == FeedType.SIP + assert FeedType.from_string("iex") == FeedType.IEX + assert FeedType.from_string("otc") == FeedType.OTC + + def test_from_string_invalid(self): + """Test creating FeedType from invalid string.""" + with pytest.raises(ValidationError) as exc_info: + FeedType.from_string("invalid") + assert "Invalid feed type: invalid" in str(exc_info.value) + + +class TestSubscriptionLevel: + """Test SubscriptionLevel enum functionality.""" + + def test_subscription_level_values(self): + """Test subscription level enum values.""" 
+        assert SubscriptionLevel.BASIC.value == "basic"
+        assert SubscriptionLevel.UNLIMITED.value == "unlimited"
+        assert SubscriptionLevel.BUSINESS.value == "business"
+
+    def test_from_error_basic(self):
+        """Test detecting basic subscription from error."""
+        error = "subscription does not permit SIP feed"
+        assert SubscriptionLevel.from_error(error) == SubscriptionLevel.BASIC
+
+    def test_from_error_unlimited(self):
+        """Test detecting unlimited subscription from error."""
+        error = "requires unlimited subscription"
+        assert SubscriptionLevel.from_error(error) == SubscriptionLevel.UNLIMITED
+
+    def test_from_error_business(self):
+        """Test that a business-tier error is detected (currently mapped to UNLIMITED)."""
+        error = "business subscription required"
+        assert SubscriptionLevel.from_error(error) == SubscriptionLevel.UNLIMITED
+
+    def test_from_error_no_match(self):
+        """Test no subscription level detected from error."""
+        error = "generic error message"
+        assert SubscriptionLevel.from_error(error) is None
+
+
+class TestFeedConfig:
+    """Test FeedConfig dataclass."""
+
+    def test_default_config(self):
+        """Test default feed configuration."""
+        config = FeedConfig()
+        assert config.preferred_feed == FeedType.SIP
+        assert config.fallback_feeds == [FeedType.IEX]
+        assert config.auto_fallback is True
+        assert config.subscription_level is None
+        assert config.endpoint_feeds == {}
+
+    def test_custom_config(self):
+        """Test custom feed configuration."""
+        config = FeedConfig(
+            preferred_feed=FeedType.IEX,
+            fallback_feeds=[FeedType.OTC, FeedType.SIP],
+            auto_fallback=False,
+            subscription_level=SubscriptionLevel.UNLIMITED,
+        )
+        assert config.preferred_feed == FeedType.IEX
+        assert config.fallback_feeds == [FeedType.OTC, FeedType.SIP]
+        assert config.auto_fallback is False
+        assert config.subscription_level == SubscriptionLevel.UNLIMITED
+
+    def test_get_feed_for_endpoint(self):
+        """Test getting feed for specific endpoint."""
+        config = FeedConfig(
+            preferred_feed=FeedType.SIP,
+            
endpoint_feeds={"quotes": FeedType.IEX, "trades": FeedType.OTC}, + ) + assert config.get_feed_for_endpoint("quotes") == FeedType.IEX + assert config.get_feed_for_endpoint("trades") == FeedType.OTC + assert config.get_feed_for_endpoint("bars") == FeedType.SIP + + +class TestFeedManager: + """Test FeedManager class.""" + + def test_init_default(self): + """Test default initialization.""" + manager = FeedManager() + assert manager.config.preferred_feed == FeedType.SIP + assert manager._failed_feeds == {} + assert manager._detected_subscription_level is None + + def test_init_with_config(self): + """Test initialization with custom config.""" + config = FeedConfig(preferred_feed=FeedType.IEX) + manager = FeedManager(config) + assert manager.config.preferred_feed == FeedType.IEX + + def test_get_feed_supported_endpoint(self): + """Test getting feed for supported endpoint.""" + manager = FeedManager() + + # Test supported endpoints + assert manager.get_feed("bars") == "sip" + assert manager.get_feed("latest/quotes") == "sip" + assert manager.get_feed("trades") == "sip" + assert manager.get_feed("snapshots") == "sip" + + def test_get_feed_unsupported_endpoint(self): + """Test getting feed for unsupported endpoint.""" + manager = FeedManager() + + # Unsupported endpoints should return None + assert manager.get_feed("account") is None + assert manager.get_feed("positions") is None + assert manager.get_feed("orders") is None + + def test_get_feed_with_endpoint_config(self): + """Test getting feed with endpoint-specific configuration.""" + config = FeedConfig( + preferred_feed=FeedType.SIP, + endpoint_feeds={"quotes": FeedType.IEX}, + ) + manager = FeedManager(config) + + assert manager.get_feed("quotes") == "iex" + assert manager.get_feed("bars") == "sip" + + def test_get_feed_with_failed_feed(self): + """Test getting feed when preferred feed has failed.""" + config = FeedConfig( + preferred_feed=FeedType.SIP, + fallback_feeds=[FeedType.IEX, FeedType.OTC], + ) + manager = 
FeedManager(config) + + # Mark SIP as failed for bars endpoint + manager._failed_feeds["bars"] = {FeedType.SIP} + + # Should fallback to IEX + assert manager.get_feed("bars") == "iex" + + # Mark IEX as also failed + manager._failed_feeds["bars"].add(FeedType.IEX) + + # Should fallback to OTC + assert manager.get_feed("bars") == "otc" + + def test_handle_feed_error_no_auto_fallback(self): + """Test handling feed error with auto_fallback disabled.""" + config = FeedConfig(auto_fallback=False) + manager = FeedManager(config) + + error = APIRequestError(403, "Access denied") + result = manager.handle_feed_error("bars", "sip", error) + + assert result is None + assert "bars" not in manager._failed_feeds + + def test_handle_feed_error_with_fallback(self): + """Test handling feed error with fallback.""" + config = FeedConfig( + preferred_feed=FeedType.SIP, + fallback_feeds=[FeedType.IEX, FeedType.OTC], + ) + manager = FeedManager(config) + + error = APIRequestError(403, "subscription does not permit SIP") + result = manager.handle_feed_error("bars", "sip", error) + + # Should return IEX as fallback + assert result == "iex" + assert FeedType.SIP in manager._failed_feeds["bars"] + assert manager._detected_subscription_level == SubscriptionLevel.BASIC + + def test_handle_feed_error_with_symbol(self): + """Test handling feed error with symbol tracking.""" + manager = FeedManager() + + error = APIRequestError(403, "Access denied") + result = manager.handle_feed_error("bars", "sip", error, symbol="AAPL") + + # Should track failure with symbol + assert result == "iex" + assert FeedType.SIP in manager._failed_feeds["bars:AAPL"] + + def test_handle_feed_error_no_alternatives(self): + """Test handling feed error with no alternatives.""" + config = FeedConfig( + preferred_feed=FeedType.SIP, + fallback_feeds=[FeedType.IEX], + ) + manager = FeedManager(config) + + # Mark all feeds as failed + manager._failed_feeds["bars"] = {FeedType.SIP, FeedType.IEX} + + error = APIRequestError(403, 
"Access denied") + result = manager.handle_feed_error("bars", "sip", error) + + assert result is None + + def test_detect_subscription_level_unlimited(self): + """Test detecting unlimited subscription level.""" + manager = FeedManager() + + # Mock successful API response + mock_client = Mock() + mock_client._make_request.return_value = Mock(status_code=200) + + level = manager.detect_subscription_level(mock_client) + + assert level == SubscriptionLevel.UNLIMITED + assert manager._detected_subscription_level == SubscriptionLevel.UNLIMITED + assert manager.config.subscription_level == SubscriptionLevel.UNLIMITED + + def test_detect_subscription_level_basic(self): + """Test detecting basic subscription level.""" + manager = FeedManager() + + # Mock failed API response + mock_client = Mock() + mock_client._make_request.side_effect = APIRequestError( + 403, "subscription does not permit SIP" + ) + + level = manager.detect_subscription_level(mock_client) + + assert level == SubscriptionLevel.BASIC + assert manager._detected_subscription_level == SubscriptionLevel.BASIC + assert manager.config.subscription_level == SubscriptionLevel.BASIC + + def test_detect_subscription_level_unknown_error(self): + """Test detecting subscription level with unknown error.""" + manager = FeedManager() + + # Mock unexpected error + mock_client = Mock() + mock_client._make_request.side_effect = APIRequestError(500, "Server error") + + level = manager.detect_subscription_level(mock_client) + + # Should default to BASIC for safety + assert level == SubscriptionLevel.BASIC + assert manager._detected_subscription_level == SubscriptionLevel.BASIC + + def test_validate_feed_supported_endpoint(self): + """Test validating feed for supported endpoint.""" + manager = FeedManager() + + assert manager.validate_feed("bars", "sip") is True + assert manager.validate_feed("bars", "iex") is True + assert manager.validate_feed("bars", "invalid") is False + + def test_validate_feed_unsupported_endpoint(self): 
+ """Test validating feed for unsupported endpoint.""" + manager = FeedManager() + + assert manager.validate_feed("account", "sip") is False + assert manager.validate_feed("positions", "iex") is False + + def test_validate_feed_with_subscription_level(self): + """Test validating feed with subscription level.""" + config = FeedConfig(subscription_level=SubscriptionLevel.BASIC) + manager = FeedManager(config) + + # Basic can only use IEX + assert manager.validate_feed("bars", "iex") is True + assert manager.validate_feed("bars", "sip") is False + assert manager.validate_feed("bars", "otc") is False + + def test_reset_failures_all(self): + """Test resetting all feed failures.""" + manager = FeedManager() + + # Add some failures + manager._failed_feeds = { + "bars": {FeedType.SIP}, + "bars:AAPL": {FeedType.IEX}, + "quotes": {FeedType.OTC}, + } + + manager.reset_failures() + + assert manager._failed_feeds == {} + + def test_reset_failures_specific_endpoint(self): + """Test resetting failures for specific endpoint.""" + manager = FeedManager() + + # Add some failures + manager._failed_feeds = { + "bars": {FeedType.SIP}, + "bars:AAPL": {FeedType.IEX}, + "bars:MSFT": {FeedType.SIP}, + "quotes": {FeedType.OTC}, + } + + manager.reset_failures("bars") + + # Only bars-related failures should be reset + assert "bars" not in manager._failed_feeds + assert "bars:AAPL" not in manager._failed_feeds + assert "bars:MSFT" not in manager._failed_feeds + assert "quotes" in manager._failed_feeds + + def test_get_available_feeds_unknown_subscription(self): + """Test getting available feeds with unknown subscription.""" + manager = FeedManager() + + feeds = manager.get_available_feeds() + + # Should return all feeds when subscription unknown + assert set(feeds) == {FeedType.SIP, FeedType.IEX, FeedType.OTC} + + def test_get_available_feeds_basic_subscription(self): + """Test getting available feeds with basic subscription.""" + config = FeedConfig(subscription_level=SubscriptionLevel.BASIC) 
+ manager = FeedManager(config) + + feeds = manager.get_available_feeds() + + assert feeds == [FeedType.IEX] + + def test_get_available_feeds_unlimited_subscription(self): + """Test getting available feeds with unlimited subscription.""" + config = FeedConfig(subscription_level=SubscriptionLevel.UNLIMITED) + manager = FeedManager(config) + + feeds = manager.get_available_feeds() + + assert set(feeds) == {FeedType.SIP, FeedType.IEX, FeedType.OTC} + + def test_get_available_feeds_detected_subscription(self): + """Test getting available feeds with detected subscription.""" + manager = FeedManager() + manager._detected_subscription_level = SubscriptionLevel.BASIC + + feeds = manager.get_available_feeds() + + assert feeds == [FeedType.IEX] + + def test_is_feed_available_unknown_subscription(self): + """Test checking feed availability with unknown subscription.""" + manager = FeedManager() + + # All feeds should be available when subscription unknown + assert manager._is_feed_available(FeedType.SIP) is True + assert manager._is_feed_available(FeedType.IEX) is True + assert manager._is_feed_available(FeedType.OTC) is True + + def test_is_feed_available_basic_subscription(self): + """Test checking feed availability with basic subscription.""" + config = FeedConfig(subscription_level=SubscriptionLevel.BASIC) + manager = FeedManager(config) + + assert manager._is_feed_available(FeedType.IEX) is True + assert manager._is_feed_available(FeedType.SIP) is False + assert manager._is_feed_available(FeedType.OTC) is False + + def test_supports_feed_endpoint(self): + """Test checking if endpoint supports feed parameter.""" + manager = FeedManager() + + # Supported endpoints + assert manager._supports_feed("bars") is True + assert manager._supports_feed("/v2/stocks/bars") is True + assert manager._supports_feed("latest/quotes") is True + assert manager._supports_feed("trades") is True + assert manager._supports_feed("snapshots") is True + + # Unsupported endpoints + assert 
manager._supports_feed("account") is False + assert manager._supports_feed("positions") is False + assert manager._supports_feed("orders") is False diff --git a/tests/test_http/test_feed_manager_integration.py b/tests/test_http/test_feed_manager_integration.py new file mode 100644 index 0000000..9592e4c --- /dev/null +++ b/tests/test_http/test_feed_manager_integration.py @@ -0,0 +1,283 @@ +from __future__ import annotations + +import os +from unittest.mock import Mock + +import pytest + +from py_alpaca_api import PyAlpacaAPI +from py_alpaca_api.exceptions import APIRequestError +from py_alpaca_api.http.feed_manager import ( + FeedConfig, + FeedManager, + FeedType, + SubscriptionLevel, +) + + +@pytest.fixture +def alpaca(): + """Create PyAlpacaAPI client for testing.""" + api_key = os.getenv("ALPACA_API_KEY") + api_secret = os.getenv("ALPACA_SECRET_KEY") + + if not api_key or not api_secret: + pytest.skip("No API credentials found") + + return PyAlpacaAPI( + api_key=api_key, + api_secret=api_secret, + api_paper=True, + ) + + +@pytest.fixture +def feed_manager(): + """Create a feed manager for testing.""" + return FeedManager() + + +class TestFeedManagerIntegration: + """Integration tests for feed manager with live API.""" + + def test_detect_subscription_level_with_live_api(self, alpaca, feed_manager): + """Test detecting subscription level with live API.""" + # Create a mock client that wraps the real client + mock_client = Mock() + + # Try to get a quote with SIP feed to test subscription + try: + _ = alpaca.stock.latest_quote.get("AAPL", feed="sip") + # If we got here, SIP is available + mock_client._make_request.return_value = Mock(status_code=200) + except APIRequestError as e: + # SIP not available + if "subscription" in str(e).lower() or "feed" in str(e).lower(): + mock_client._make_request.side_effect = e + else: + # Different error, skip test + pytest.skip(f"Unexpected error: {e}") + + # Detect subscription level + level = 
feed_manager.detect_subscription_level(mock_client) + + assert level in [SubscriptionLevel.BASIC, SubscriptionLevel.UNLIMITED] + assert feed_manager._detected_subscription_level == level + + def test_feed_fallback_with_live_api(self, alpaca): + """Test feed fallback behavior with live API.""" + # Try to get quotes with different feeds + feeds_tested = [] + successful_feed = None + + for feed_type in [FeedType.SIP, FeedType.IEX]: + try: + _ = alpaca.stock.latest_quote.get("AAPL", feed=feed_type.value) + feeds_tested.append((feed_type, True)) + successful_feed = feed_type + break + except APIRequestError: + feeds_tested.append((feed_type, False)) + continue + + # At least one feed should work + assert successful_feed is not None, f"No feeds worked: {feeds_tested}" + + def test_feed_manager_with_bars_endpoint(self, alpaca): + """Test feed manager with bars endpoint.""" + manager = FeedManager() + + # Test with bars endpoint + feed = manager.get_feed("bars") + assert feed in ["sip", "iex", "otc"] + + # Try to fetch bars with the suggested feed + try: + bars = alpaca.stock.history.get_stock_data( + symbol="AAPL", + start="2024-01-01", + end="2024-12-31", + timeframe="1d", + limit=10, + feed=feed, + ) + # If successful, feed is appropriate + assert bars is not None + except APIRequestError as e: + # Feed not available, manager should handle this + alternative = manager.handle_feed_error("bars", feed, e, symbol="AAPL") + if alternative: + # Try with alternative feed + bars = alpaca.stock.history.get_stock_data( + symbol="AAPL", + start="2024-01-01", + end="2024-12-31", + timeframe="1d", + limit=10, + feed=alternative, + ) + assert bars is not None + + def test_feed_manager_with_quotes_endpoint(self, alpaca): + """Test feed manager with quotes endpoint.""" + manager = FeedManager() + + # Test with quotes endpoint + feed = manager.get_feed("latest/quotes") + assert feed in ["sip", "iex", "otc"] + + # Try to fetch quote with the suggested feed + try: + quote = 
alpaca.stock.latest_quote.get("AAPL", feed=feed) + assert quote is not None + except APIRequestError as e: + # Feed not available, manager should handle this + alternative = manager.handle_feed_error( + "latest/quotes", feed, e, symbol="AAPL" + ) + if alternative: + # Try with alternative feed + quote = alpaca.stock.latest_quote.get("AAPL", feed=alternative) + assert quote is not None + + def test_feed_manager_with_trades_endpoint(self, alpaca): + """Test feed manager with trades endpoint.""" + manager = FeedManager() + + # Test with trades endpoint + feed = manager.get_feed("trades") + assert feed in ["sip", "iex", "otc"] + + # Try to fetch trades with the suggested feed + try: + trades = alpaca.stock.trades.get_latest_trade("AAPL", feed=feed) + assert trades is not None + except APIRequestError as e: + # Feed not available, manager should handle this + alternative = manager.handle_feed_error("trades", feed, e, symbol="AAPL") + if alternative: + # Try with alternative feed + trades = alpaca.stock.trades.get_latest_trade("AAPL", feed=alternative) + assert trades is not None + + def test_feed_validation_with_live_data(self, alpaca): + """Test feed validation based on actual API access.""" + manager = FeedManager() + + # Test validation for bars endpoint + assert manager.validate_feed("bars", "iex") is True + assert manager.validate_feed("bars", "invalid_feed") is False + + # Test validation for non-feed endpoint + assert manager.validate_feed("account", "iex") is False + + def test_feed_manager_caching_behavior(self, alpaca): + """Test that feed manager caches failed feeds appropriately.""" + manager = FeedManager( + FeedConfig( + preferred_feed=FeedType.SIP, + fallback_feeds=[FeedType.IEX], + ) + ) + + # First request + feed1 = manager.get_feed("bars", symbol="AAPL") + + # Simulate a failure if using SIP + if feed1 == "sip": + try: + _ = alpaca.stock.history.get_stock_data( + symbol="AAPL", + start="2024-01-01", + end="2024-12-31", + timeframe="1d", + limit=1, + 
feed=feed1, + ) + except APIRequestError as e: + # Handle the error + alternative = manager.handle_feed_error("bars", feed1, e, symbol="AAPL") + + # Second request should return alternative directly + feed2 = manager.get_feed("bars", symbol="AAPL") + assert feed2 in {alternative, "iex"} + + def test_feed_manager_reset_failures(self): + """Test resetting feed failures.""" + manager = FeedManager() + + # Add some failures + error = APIRequestError(403, "Access denied") + manager.handle_feed_error("bars", "sip", error, symbol="AAPL") + manager.handle_feed_error("quotes", "sip", error) + + assert len(manager._failed_feeds) > 0 + + # Reset all failures + manager.reset_failures() + assert len(manager._failed_feeds) == 0 + + def test_multiple_symbols_with_feed_manager(self, alpaca): + """Test feed manager with multiple symbols.""" + manager = FeedManager() + + symbols = ["AAPL", "GOOGL", "MSFT"] + successful_fetches = [] + + for symbol in symbols: + feed = manager.get_feed("latest/quotes", symbol=symbol) + + try: + _ = alpaca.stock.latest_quote.get(symbol, feed=feed) + successful_fetches.append((symbol, feed, True)) + except APIRequestError as e: + # Try fallback + alternative = manager.handle_feed_error( + "latest/quotes", feed, e, symbol=symbol + ) + if alternative: + try: + _ = alpaca.stock.latest_quote.get(symbol, feed=alternative) + successful_fetches.append((symbol, alternative, True)) + except APIRequestError: + successful_fetches.append((symbol, alternative, False)) + else: + successful_fetches.append((symbol, feed, False)) + + # At least some symbols should succeed + successful_count = sum(1 for _, _, success in successful_fetches if success) + assert ( + successful_count > 0 + ), f"Failed to fetch any symbols: {successful_fetches}" + + def test_feed_config_endpoint_specific(self, alpaca): + """Test endpoint-specific feed configuration.""" + config = FeedConfig( + preferred_feed=FeedType.SIP, + endpoint_feeds={ + "latest/quotes": FeedType.IEX, + "bars": 
FeedType.SIP, + }, + ) + manager = FeedManager(config) + + # Check that endpoint-specific config is used + assert manager.get_feed("latest/quotes") == "iex" + assert manager.get_feed("bars") == "sip" + assert manager.get_feed("trades") == "sip" # Uses default + + def test_subscription_level_affects_available_feeds(self): + """Test that subscription level affects available feeds.""" + # Test with BASIC subscription + config_basic = FeedConfig(subscription_level=SubscriptionLevel.BASIC) + manager_basic = FeedManager(config_basic) + + available_basic = manager_basic.get_available_feeds() + assert available_basic == [FeedType.IEX] + + # Test with UNLIMITED subscription + config_unlimited = FeedConfig(subscription_level=SubscriptionLevel.UNLIMITED) + manager_unlimited = FeedManager(config_unlimited) + + available_unlimited = manager_unlimited.get_available_feeds() + assert set(available_unlimited) == {FeedType.SIP, FeedType.IEX, FeedType.OTC} diff --git a/tests/test_integration/test_account_config_integration.py b/tests/test_integration/test_account_config_integration.py new file mode 100644 index 0000000..423e7a9 --- /dev/null +++ b/tests/test_integration/test_account_config_integration.py @@ -0,0 +1,183 @@ +import os + +import pytest + +from py_alpaca_api import PyAlpacaAPI +from py_alpaca_api.models.account_config_model import AccountConfigModel + + +@pytest.mark.skipif( + not os.environ.get("ALPACA_API_KEY") or not os.environ.get("ALPACA_SECRET_KEY"), + reason="API credentials not set", +) +class TestAccountConfigIntegration: + @pytest.fixture + def alpaca(self): + return PyAlpacaAPI( + api_key=os.environ.get("ALPACA_API_KEY"), + api_secret=os.environ.get("ALPACA_SECRET_KEY"), + api_paper=True, + ) + + @pytest.fixture + def original_config(self, alpaca): + """Get the original configuration to restore after tests.""" + return alpaca.trading.account.get_configuration() + + def test_get_configuration(self, alpaca): + config = alpaca.trading.account.get_configuration() 
+ + assert isinstance(config, AccountConfigModel) + # These fields should always be present + assert config.dtbp_check in ["entry", "exit", "both"] + assert isinstance(config.fractional_trading, bool) + assert config.max_margin_multiplier in ["1", "2", "4"] + assert isinstance(config.no_shorting, bool) + assert config.pdt_check in ["entry", "exit", "both"] + assert isinstance(config.ptp_no_exception_entry, bool) + assert isinstance(config.suspend_trade, bool) + assert config.trade_confirm_email in ["all", "none"] + + def test_update_single_configuration_param(self, alpaca, original_config): + # Toggle trade confirmation email + new_setting = "none" if original_config.trade_confirm_email == "all" else "all" + + updated_config = alpaca.trading.account.update_configuration( + trade_confirm_email=new_setting + ) + + assert isinstance(updated_config, AccountConfigModel) + assert updated_config.trade_confirm_email == new_setting + + # Verify other settings remain unchanged + assert updated_config.dtbp_check == original_config.dtbp_check + assert updated_config.fractional_trading == original_config.fractional_trading + assert updated_config.no_shorting == original_config.no_shorting + assert updated_config.pdt_check == original_config.pdt_check + assert updated_config.suspend_trade == original_config.suspend_trade + + # Restore original setting + alpaca.trading.account.update_configuration( + trade_confirm_email=original_config.trade_confirm_email + ) + + def test_update_multiple_configuration_params(self, alpaca, original_config): + # Toggle no_shorting and change pdt_check + new_no_shorting = not original_config.no_shorting + new_pdt_check = "exit" if original_config.pdt_check == "entry" else "entry" + + updated_config = alpaca.trading.account.update_configuration( + no_shorting=new_no_shorting, pdt_check=new_pdt_check + ) + + assert isinstance(updated_config, AccountConfigModel) + assert updated_config.no_shorting == new_no_shorting + assert updated_config.pdt_check == 
new_pdt_check + + # Verify other settings remain unchanged + assert updated_config.dtbp_check == original_config.dtbp_check + assert updated_config.fractional_trading == original_config.fractional_trading + assert ( + updated_config.max_margin_multiplier + == original_config.max_margin_multiplier + ) + assert updated_config.suspend_trade == original_config.suspend_trade + assert updated_config.trade_confirm_email == original_config.trade_confirm_email + + # Restore original settings + alpaca.trading.account.update_configuration( + no_shorting=original_config.no_shorting, + pdt_check=original_config.pdt_check, + ) + + def test_update_margin_multiplier(self, alpaca, original_config): + # Test changing margin multiplier + current_multiplier = original_config.max_margin_multiplier + new_multiplier = "2" if current_multiplier != "2" else "4" + + updated_config = alpaca.trading.account.update_configuration( + max_margin_multiplier=new_multiplier + ) + + assert updated_config.max_margin_multiplier == new_multiplier + + # Restore original + alpaca.trading.account.update_configuration( + max_margin_multiplier=original_config.max_margin_multiplier + ) + + def test_update_dtbp_check(self, alpaca, original_config): + # Cycle through dtbp_check options + options = ["entry", "exit", "both"] + current = original_config.dtbp_check + new_value = options[(options.index(current) + 1) % 3] + + updated_config = alpaca.trading.account.update_configuration( + dtbp_check=new_value + ) + + assert updated_config.dtbp_check == new_value + + # Restore original + alpaca.trading.account.update_configuration( + dtbp_check=original_config.dtbp_check + ) + + def test_toggle_fractional_trading(self, alpaca, original_config): + # Toggle fractional trading + new_value = not original_config.fractional_trading + + updated_config = alpaca.trading.account.update_configuration( + fractional_trading=new_value + ) + + assert updated_config.fractional_trading == new_value + + # Restore original + 
alpaca.trading.account.update_configuration( + fractional_trading=original_config.fractional_trading + ) + + def test_configuration_persistence(self, alpaca, original_config): + # Update a configuration + new_email_setting = ( + "none" if original_config.trade_confirm_email == "all" else "all" + ) + alpaca.trading.account.update_configuration( + trade_confirm_email=new_email_setting + ) + + # Get configuration again to verify persistence + config = alpaca.trading.account.get_configuration() + assert config.trade_confirm_email == new_email_setting + + # Restore original + alpaca.trading.account.update_configuration( + trade_confirm_email=original_config.trade_confirm_email + ) + + def test_invalid_parameter_handling(self, alpaca): + # Test that invalid parameters raise appropriate errors + with pytest.raises(ValueError): + alpaca.trading.account.update_configuration(dtbp_check="invalid") + + with pytest.raises(ValueError): + alpaca.trading.account.update_configuration(pdt_check="invalid") + + with pytest.raises(ValueError): + alpaca.trading.account.update_configuration(max_margin_multiplier="5") + + with pytest.raises(ValueError): + alpaca.trading.account.update_configuration(trade_confirm_email="sometimes") + + @pytest.mark.skip(reason="Suspend trade affects account functionality") + def test_suspend_trade_toggle(self, alpaca, original_config): + # This test is skipped as it would actually suspend trading + # Only run manually when testing this specific feature + updated_config = alpaca.trading.account.update_configuration(suspend_trade=True) + assert updated_config.suspend_trade is True + + # Immediately restore + alpaca.trading.account.update_configuration( + suspend_trade=original_config.suspend_trade + ) diff --git a/tests/test_integration/test_metadata_integration.py b/tests/test_integration/test_metadata_integration.py new file mode 100644 index 0000000..e489a0b --- /dev/null +++ b/tests/test_integration/test_metadata_integration.py @@ -0,0 +1,195 @@ +import 
os + +import pytest + +from py_alpaca_api import PyAlpacaAPI + + +@pytest.mark.skipif( + not os.environ.get("ALPACA_API_KEY") or not os.environ.get("ALPACA_SECRET_KEY"), + reason="API credentials not set", +) +class TestMetadataIntegration: + @pytest.fixture + def alpaca(self): + return PyAlpacaAPI( + api_key=os.environ.get("ALPACA_API_KEY"), + api_secret=os.environ.get("ALPACA_SECRET_KEY"), + api_paper=True, + ) + + def test_get_exchange_codes(self, alpaca): + exchanges = alpaca.stock.metadata.get_exchange_codes() + + assert isinstance(exchanges, dict) + assert len(exchanges) > 0 + + # Check for some common exchanges + assert "A" in exchanges # NYSE American + assert "N" in exchanges # NYSE + assert "Q" in exchanges # NASDAQ + assert "V" in exchanges # IEX + assert "P" in exchanges # NYSE Arca + + # Verify values are strings + for code, name in exchanges.items(): + assert isinstance(code, str) + assert isinstance(name, str) + assert len(name) > 0 + + print(f"Found {len(exchanges)} exchange codes") + + def test_get_condition_codes_trade_tape_a(self, alpaca): + conditions = alpaca.stock.metadata.get_condition_codes( + ticktype="trade", tape="A" + ) + + assert isinstance(conditions, dict) + assert len(conditions) > 0 + + # Check for common condition codes + if "" in conditions: + assert conditions[""] == "Regular Sale" + + # Verify all values are strings + for code, description in conditions.items(): + assert isinstance(code, str) + assert isinstance(description, str) + + print(f"Found {len(conditions)} trade conditions for Tape A") + + def test_get_condition_codes_trade_tape_b(self, alpaca): + conditions = alpaca.stock.metadata.get_condition_codes( + ticktype="trade", tape="B" + ) + + assert isinstance(conditions, dict) + assert len(conditions) > 0 + + print(f"Found {len(conditions)} trade conditions for Tape B") + + def test_get_condition_codes_quote(self, alpaca): + conditions = alpaca.stock.metadata.get_condition_codes( + ticktype="quote", tape="A" + ) + + assert 
isinstance(conditions, dict) + # Quote conditions might be fewer than trade conditions + assert len(conditions) >= 0 + + # Verify all values are strings + for code, description in conditions.items(): + assert isinstance(code, str) + assert isinstance(description, str) + + print(f"Found {len(conditions)} quote conditions for Tape A") + + def test_get_all_condition_codes(self, alpaca): + all_conditions = alpaca.stock.metadata.get_all_condition_codes() + + assert isinstance(all_conditions, dict) + assert "trade" in all_conditions + assert "quote" in all_conditions + + # Check structure + for ticktype in ["trade", "quote"]: + assert ticktype in all_conditions + for tape in ["A", "B", "C"]: + assert tape in all_conditions[ticktype] + assert isinstance(all_conditions[ticktype][tape], dict) + + # Count total conditions + total = 0 + for ticktype in all_conditions: + for tape in all_conditions[ticktype]: + total += len(all_conditions[ticktype][tape]) + + print(f"Found {total} total condition codes across all types and tapes") + + def test_lookup_exchange(self, alpaca): + # Test valid exchange codes + nasdaq = alpaca.stock.metadata.lookup_exchange("Q") + assert nasdaq is not None + assert "NASDAQ" in nasdaq + + nyse = alpaca.stock.metadata.lookup_exchange("N") + assert nyse is not None + assert "New York Stock Exchange" in nyse + + iex = alpaca.stock.metadata.lookup_exchange("V") + assert iex is not None + assert "IEX" in iex + + # Test invalid code + invalid = alpaca.stock.metadata.lookup_exchange("ZZ") + assert invalid is None + + def test_lookup_condition(self, alpaca): + # Test looking up a specific condition + # Empty string is often "Regular Sale" + regular = alpaca.stock.metadata.lookup_condition("", ticktype="trade", tape="A") + if regular: + assert "Regular" in regular or "Sale" in regular + + # Test invalid condition + invalid = alpaca.stock.metadata.lookup_condition( + "ZZ", ticktype="trade", tape="A" + ) + assert invalid is None + + def 
test_caching_behavior(self, alpaca): + # Clear cache first + alpaca.stock.metadata.clear_cache() + + # First call should hit API + exchanges1 = alpaca.stock.metadata.get_exchange_codes() + + # Second call should use cache (should be faster) + exchanges2 = alpaca.stock.metadata.get_exchange_codes() + + assert exchanges1 == exchanges2 + + # Force API call by disabling cache + exchanges3 = alpaca.stock.metadata.get_exchange_codes(use_cache=False) + + assert exchanges3 == exchanges1 + + def test_clear_cache(self, alpaca): + # Load some data into cache + alpaca.stock.metadata.get_exchange_codes() + alpaca.stock.metadata.get_condition_codes(ticktype="trade", tape="A") + + # Verify cache is populated + assert alpaca.stock.metadata._exchange_cache is not None + assert len(alpaca.stock.metadata._condition_cache) > 0 + + # Clear cache + alpaca.stock.metadata.clear_cache() + + # Verify cache is cleared + assert alpaca.stock.metadata._exchange_cache is None + assert len(alpaca.stock.metadata._condition_cache) == 0 + + def test_different_tapes_have_same_conditions(self, alpaca): + # Get conditions for different tapes + tape_a = alpaca.stock.metadata.get_condition_codes(ticktype="trade", tape="A") + tape_b = alpaca.stock.metadata.get_condition_codes(ticktype="trade", tape="B") + tape_c = alpaca.stock.metadata.get_condition_codes(ticktype="trade", tape="C") + + # Tapes often have similar condition codes + # Check if they have some overlap + common_codes = set(tape_a.keys()) & set(tape_b.keys()) + assert len(common_codes) > 0 + + print(f"Tape A: {len(tape_a)} conditions") + print(f"Tape B: {len(tape_b)} conditions") + print(f"Tape C: {len(tape_c)} conditions") + print(f"Common codes between A and B: {len(common_codes)}") + + def test_exchange_codes_are_consistent(self, alpaca): + # Get exchanges multiple times to ensure consistency + exchanges1 = alpaca.stock.metadata.get_exchange_codes(use_cache=False) + exchanges2 = alpaca.stock.metadata.get_exchange_codes(use_cache=False) + + 
assert exchanges1 == exchanges2 + assert len(exchanges1) == len(exchanges2) diff --git a/tests/test_integration/test_order_enhancements_integration.py b/tests/test_integration/test_order_enhancements_integration.py new file mode 100644 index 0000000..1e41125 --- /dev/null +++ b/tests/test_integration/test_order_enhancements_integration.py @@ -0,0 +1,282 @@ +import contextlib +import os +import time + +import pytest + +from py_alpaca_api import PyAlpacaAPI +from py_alpaca_api.exceptions import APIRequestError + + +@pytest.mark.skipif( + not os.environ.get("ALPACA_API_KEY") or not os.environ.get("ALPACA_SECRET_KEY"), + reason="API credentials not set", +) +class TestOrderEnhancementsIntegration: + @pytest.fixture + def alpaca(self): + return PyAlpacaAPI( + api_key=os.environ.get("ALPACA_API_KEY"), + api_secret=os.environ.get("ALPACA_SECRET_KEY"), + api_paper=True, + ) + + @pytest.fixture(autouse=True) + def cleanup_orders(self, alpaca): + """Cancel all orders before and after each test.""" + # Cancel before test + with contextlib.suppress(Exception): + alpaca.trading.orders.cancel_all() + + yield + + # Cancel after test + with contextlib.suppress(Exception): + alpaca.trading.orders.cancel_all() + + def test_market_order_with_client_order_id(self, alpaca): + # Submit order with client ID + client_id = f"test-market-{int(time.time())}" + order = alpaca.trading.orders.market( + symbol="AAPL", + qty=1, + side="buy", + client_order_id=client_id, + ) + + assert order.client_order_id == client_id + assert order.symbol == "AAPL" + assert order.qty == 1 + + # Retrieve order by client ID + retrieved = alpaca.trading.orders.get_by_client_order_id(client_id) + assert retrieved.id == order.id + assert retrieved.client_order_id == client_id + + # Cancel by client ID + result = alpaca.trading.orders.cancel_by_client_order_id(client_id) + assert "cancelled" in result.lower() + + def test_limit_order_with_extended_hours(self, alpaca): + # Submit limit order with extended hours + 
client_id = f"test-limit-ext-{int(time.time())}" + order = alpaca.trading.orders.limit( + symbol="AAPL", + limit_price=150.00, + qty=1, + side="buy", + extended_hours=True, + client_order_id=client_id, + time_in_force="day", + ) + + assert order.extended_hours is True + assert order.client_order_id == client_id + + # Cancel the order + alpaca.trading.orders.cancel_by_id(order.id) + + def test_replace_order(self, alpaca): + # Submit initial limit order + order = alpaca.trading.orders.limit( + symbol="AAPL", + limit_price=100.00, # Very low price to avoid fill + qty=1, + side="buy", + time_in_force="gtc", + ) + + assert order.qty == 1 + assert order.limit_price == 100.00 + + # Wait a moment for order to be fully registered + time.sleep(0.5) + + # Only test replace if order is in correct state + if order.status in ["new", "partially_filled"]: + # Replace order with new parameters + replaced_order = alpaca.trading.orders.replace_order( + order_id=order.id, + qty=2, + limit_price=101.00, + ) + + assert replaced_order.qty == 2 + assert replaced_order.limit_price == 101.00 + assert replaced_order.symbol == "AAPL" + + # Cancel the order + alpaca.trading.orders.cancel_by_id(replaced_order.id) + else: + # Order already accepted/filled, just cancel it + with contextlib.suppress(Exception): + alpaca.trading.orders.cancel_by_id(order.id) + + def test_order_class_oto(self, alpaca): + """Test One-Triggers-Other (OTO) order class.""" + # Note: OTO orders require specific account permissions + # This test may fail if the account doesn't support OTO orders + try: + client_id = f"test-oto-{int(time.time())}" + order = alpaca.trading.orders.limit( + symbol="AAPL", + limit_price=100.00, # Low price to avoid fill + qty=1, + side="buy", + order_class="oto", + take_profit=150.00, # OTO needs either take_profit OR stop_loss + client_order_id=client_id, + ) + + if order: + assert order.order_class in ["oto", "simple"] # May fallback to simple + alpaca.trading.orders.cancel_by_id(order.id) + 
except APIRequestError as e: + # OTO might not be supported + if "order class" not in str(e).lower(): + raise + + def test_order_class_oco(self, alpaca): + """Test One-Cancels-Other (OCO) order class.""" + # Note: OCO orders are exit-only orders and require an existing position + # Since we may not have a position, we'll test that the API properly rejects + # OCO orders when no position exists, or skip if we get the expected error + try: + client_id = f"test-oco-{int(time.time())}" + order = alpaca.trading.orders.limit( + symbol="AAPL", + limit_price=100.00, # Low price to avoid fill + qty=1, + side="sell", # OCO orders are exit orders, so we'd sell to close a long position + order_class="oco", + take_profit=150.00, # Take profit at higher price for sell order + stop_loss=80.00, # Stop loss at lower price for sell order + client_order_id=client_id, + ) + + # If we somehow have a position and the order succeeds + if order: + assert order.order_class in ["oco", "simple"] # May fallback to simple + alpaca.trading.orders.cancel_by_id(order.id) + except APIRequestError as e: + # Expected errors since OCO orders are exit-only + error_msg = str(e).lower() + if "oco orders must be exit orders" in error_msg: + pass # Expected since we don't have a position + elif "insufficient" in error_msg or "position" in error_msg: + pass # Also expected if no position exists + elif "order class" in error_msg: + pass # OCO might not be supported on account + else: + raise # Unexpected error + + def test_bracket_order_with_explicit_class(self, alpaca): + """Test bracket order with explicit order_class.""" + client_id = f"test-bracket-{int(time.time())}" + order = alpaca.trading.orders.market( + symbol="AAPL", + qty=1, + side="buy", + take_profit=500.00, # High take profit (well above current price ~$240) + stop_loss=50.00, # Low stop loss + order_class="bracket", # Explicitly set + client_order_id=client_id, + ) + + assert order.order_class == "bracket" + assert order.client_order_id == 
client_id + + # Cancel the order + alpaca.trading.orders.cancel_by_id(order.id) + + def test_stop_order_with_client_id(self, alpaca): + client_id = f"test-stop-{int(time.time())}" + order = alpaca.trading.orders.stop( + symbol="AAPL", + stop_price=300.00, # High stop price for buy to avoid trigger + qty=1, + side="buy", + client_order_id=client_id, + ) + + assert order.client_order_id == client_id + assert order.stop_price == 300.00 + + # Cancel the order + alpaca.trading.orders.cancel_by_client_order_id(client_id) + + def test_trailing_stop_with_enhancements(self, alpaca): + client_id = f"test-trail-{int(time.time())}" + order = alpaca.trading.orders.trailing_stop( + symbol="AAPL", + qty=1, + trail_percent=10.0, # 10% trailing stop + side="sell", + client_order_id=client_id, + ) + + assert order.client_order_id == client_id + assert order.trail_percent == 10.0 + + # Cancel the order + alpaca.trading.orders.cancel_by_id(order.id) + + def test_multiple_orders_with_client_ids(self, alpaca): + """Test managing multiple orders with client IDs.""" + client_ids = [f"test-multi-{i}-{int(time.time())}" for i in range(3)] + orders = [] + + # Submit multiple orders + for i, client_id in enumerate(client_ids): + order = alpaca.trading.orders.limit( + symbol="AAPL", + limit_price=100.00 + i, # Different prices + qty=1, + side="buy", + client_order_id=client_id, + ) + orders.append(order) + + # Verify we can retrieve each by client ID + for client_id, order in zip(client_ids, orders, strict=False): + retrieved = alpaca.trading.orders.get_by_client_order_id(client_id) + assert retrieved.id == order.id + assert retrieved.client_order_id == client_id + + # Cancel all orders + for client_id in client_ids: + alpaca.trading.orders.cancel_by_client_order_id(client_id) + + def test_replace_order_time_in_force(self, alpaca): + """Test replacing order's time_in_force parameter.""" + # Submit initial order with day time_in_force + order = alpaca.trading.orders.limit( + symbol="AAPL", + 
limit_price=100.00, + qty=1, + side="buy", + time_in_force="day", + ) + + assert order.time_in_force == "day" + + # Wait a moment for order to be fully registered + time.sleep(0.5) + + # Only test replace if order is in correct state + if order.status in ["new", "partially_filled"]: + # Replace with gtc time_in_force + replaced = alpaca.trading.orders.replace_order( + order_id=order.id, + time_in_force="gtc", + ) + + assert replaced.time_in_force == "gtc" + assert replaced.qty == order.qty # Qty should remain the same + + # Cancel the order + alpaca.trading.orders.cancel_by_id(replaced.id) + else: + # Order already accepted/filled, just cancel it + with contextlib.suppress(Exception): + alpaca.trading.orders.cancel_by_id(order.id) diff --git a/tests/test_integration/test_snapshots_integration.py b/tests/test_integration/test_snapshots_integration.py new file mode 100644 index 0000000..3a4790a --- /dev/null +++ b/tests/test_integration/test_snapshots_integration.py @@ -0,0 +1,217 @@ +import os + +import pytest + +from py_alpaca_api import PyAlpacaAPI +from py_alpaca_api.models.snapshot_model import BarModel, SnapshotModel + + +@pytest.mark.skipif( + not os.environ.get("ALPACA_API_KEY") or not os.environ.get("ALPACA_SECRET_KEY"), + reason="API credentials not set", +) +class TestSnapshotsIntegration: + @pytest.fixture + def alpaca(self): + return PyAlpacaAPI( + api_key=os.environ.get("ALPACA_API_KEY"), + api_secret=os.environ.get("ALPACA_SECRET_KEY"), + api_paper=True, + ) + + def test_get_snapshot_single_symbol(self, alpaca): + result = alpaca.stock.snapshots.get_snapshot("AAPL") + + assert isinstance(result, SnapshotModel) + assert result.symbol == "AAPL" + + # At least one of these should be present during market hours + # Some may be None during pre/post market + assert any( + [ + result.latest_trade is not None, + result.latest_quote is not None, + result.minute_bar is not None, + result.daily_bar is not None, + result.prev_daily_bar is not None, + ] + ) + + # 
If latest_trade exists, validate its structure + if result.latest_trade: + assert result.latest_trade.price > 0 + assert result.latest_trade.size >= 0 + assert result.latest_trade.symbol == "AAPL" + + # If latest_quote exists, validate its structure + if result.latest_quote: + assert result.latest_quote.ask >= 0 + assert result.latest_quote.bid >= 0 + assert result.latest_quote.symbol == "AAPL" + + # If minute_bar exists, validate its structure + if result.minute_bar: + assert isinstance(result.minute_bar, BarModel) + assert result.minute_bar.open > 0 + assert result.minute_bar.high >= result.minute_bar.low + assert result.minute_bar.volume >= 0 + + # If daily_bar exists, validate its structure + if result.daily_bar: + assert isinstance(result.daily_bar, BarModel) + assert result.daily_bar.open > 0 + assert result.daily_bar.high >= result.daily_bar.low + assert result.daily_bar.volume >= 0 + + # If prev_daily_bar exists, validate its structure + if result.prev_daily_bar: + assert isinstance(result.prev_daily_bar, BarModel) + assert result.prev_daily_bar.open > 0 + assert result.prev_daily_bar.high >= result.prev_daily_bar.low + assert result.prev_daily_bar.volume >= 0 + + def test_get_snapshot_with_different_feeds(self, alpaca): + # Test with IEX feed (default) + result_iex = alpaca.stock.snapshots.get_snapshot("AAPL", feed="iex") + assert isinstance(result_iex, SnapshotModel) + assert result_iex.symbol == "AAPL" + + # Test with SIP feed if available (might require subscription) + try: + result_sip = alpaca.stock.snapshots.get_snapshot("AAPL", feed="sip") + assert isinstance(result_sip, SnapshotModel) + assert result_sip.symbol == "AAPL" + except Exception: + # SIP feed might not be available for all accounts + pass + + def test_get_snapshots_multiple_symbols(self, alpaca): + symbols = ["AAPL", "MSFT", "GOOGL"] + result = alpaca.stock.snapshots.get_snapshots(symbols) + + assert isinstance(result, dict) + assert len(result) == 3 + + for symbol in symbols: + assert 
symbol in result + assert isinstance(result[symbol], SnapshotModel) + assert result[symbol].symbol == symbol + + # Validate at least some data is present + snapshot = result[symbol] + assert any( + [ + snapshot.latest_trade is not None, + snapshot.latest_quote is not None, + snapshot.minute_bar is not None, + snapshot.daily_bar is not None, + snapshot.prev_daily_bar is not None, + ] + ) + + def test_get_snapshots_single_symbol_returns_list(self, alpaca): + result = alpaca.stock.snapshots.get_snapshots("AAPL") + + assert isinstance(result, list) + assert len(result) == 1 + assert isinstance(result[0], SnapshotModel) + assert result[0].symbol == "AAPL" + + def test_get_snapshots_comma_separated_string(self, alpaca): + result = alpaca.stock.snapshots.get_snapshots("AAPL,MSFT,TSLA") + + assert isinstance(result, dict) + assert len(result) == 3 + assert all(symbol in result for symbol in ["AAPL", "MSFT", "TSLA"]) + + def test_get_snapshots_large_batch(self, alpaca): + # Test with a larger batch of symbols + symbols = [ + "AAPL", + "MSFT", + "GOOGL", + "AMZN", + "META", + "TSLA", + "NVDA", + "JPM", + "V", + "JNJ", + ] + result = alpaca.stock.snapshots.get_snapshots(symbols) + + assert isinstance(result, dict) + assert len(result) == 10 + + for symbol in symbols: + assert symbol in result + assert isinstance(result[symbol], SnapshotModel) + assert result[symbol].symbol == symbol + + def test_get_snapshot_with_otc_symbol(self, alpaca): + # Test with an OTC symbol if available + try: + result = alpaca.stock.snapshots.get_snapshot("TCEHY", feed="otc") + assert isinstance(result, SnapshotModel) + assert result.symbol == "TCEHY" + except Exception: + # OTC feed might not be available or symbol might not exist + pass + + def test_snapshot_data_consistency(self, alpaca): + # Get snapshot and verify data consistency + result = alpaca.stock.snapshots.get_snapshot("SPY") + + assert isinstance(result, SnapshotModel) + assert result.symbol == "SPY" + + # If we have both minute and 
daily bars, check consistency + if result.minute_bar and result.daily_bar: + # Minute bar should be within the daily bar's range + assert result.minute_bar.high <= result.daily_bar.high + assert result.minute_bar.low >= result.daily_bar.low + + # If we have latest trade and minute bar, check consistency + if result.latest_trade and result.minute_bar: + # Latest trade should be within reasonable range of minute bar + # (allowing for some time difference) + assert result.latest_trade.price > 0 + assert result.minute_bar.close > 0 + + def test_get_snapshots_handles_invalid_symbol_gracefully(self, alpaca): + # Mix valid and potentially invalid symbols + symbols = ["AAPL", "INVALID123", "MSFT"] + + try: + result = alpaca.stock.snapshots.get_snapshots(symbols) + # If it doesn't error, check we at least got valid symbols + assert "AAPL" in result or "MSFT" in result + except Exception: + # API might reject the entire request with invalid symbols + # This is acceptable behavior + pass + + def test_snapshot_during_market_hours(self, alpaca): + # This test is most meaningful during market hours + # Get snapshot for a highly liquid symbol + result = alpaca.stock.snapshots.get_snapshot("SPY") + + assert isinstance(result, SnapshotModel) + assert result.symbol == "SPY" + + # During market hours, SPY should have most data available + # Note: This might fail outside market hours + try: + market_info = alpaca.trading.market.get_market_info() + market_open = ( + market_info.is_open if hasattr(market_info, "is_open") else False + ) + except Exception: + market_open = False + + if market_open: + # During market hours, we expect more data to be available + assert result.latest_trade is not None + assert result.latest_quote is not None + # Minute bar might not be immediately available at market open + # Daily bar updates throughout the day diff --git a/tests/test_stock/test_batch_operations.py b/tests/test_stock/test_batch_operations.py new file mode 100644 index 0000000..e49888c --- 
/dev/null
+++ b/tests/test_stock/test_batch_operations.py
@@ -0,0 +1,398 @@
+"""Tests for batch operations in history and latest_quote modules."""
+
+import os
+from unittest.mock import MagicMock, patch
+
+import pandas as pd
+import pytest
+
+from py_alpaca_api import PyAlpacaAPI
+from py_alpaca_api.models.quote_model import QuoteModel
+
+
+class TestBatchOperations:
+    """Test batch operations for multi-symbol data retrieval."""
+
+    @pytest.fixture
+    def alpaca(self):
+        """Create PyAlpacaAPI instance for testing."""
+        return PyAlpacaAPI(
+            api_key=os.environ.get("ALPACA_API_KEY", "test_key"),
+            api_secret=os.environ.get("ALPACA_SECRET_KEY", "test_secret"),
+            api_paper=True,
+        )
+
+    @pytest.fixture
+    def mock_requests(self):
+        """Mock the Requests class for unit tests."""
+        with patch("py_alpaca_api.stock.history.Requests") as mock:
+            yield mock
+
+    @pytest.fixture
+    def mock_quotes_requests(self):
+        """Mock the Requests class for quote tests."""
+        with patch("py_alpaca_api.stock.latest_quote.Requests") as mock:
+            yield mock
+
+    def test_history_single_symbol(self, alpaca, mock_requests):
+        """Test getting historical data for a single symbol."""
+        # Setup mock response
+        mock_response = MagicMock()
+        mock_response.text = '{"bars": [{"t": "2024-01-01T09:30:00Z", "o": 100, "h": 105, "l": 99, "c": 103, "v": 1000000, "n": 500, "vw": 102.5}]}'
+        mock_requests.return_value.request.return_value = mock_response
+
+        # Mock the asset check
+        with patch.object(alpaca.stock.history, "check_if_stock", return_value=None):
+            # Test single symbol
+            df = alpaca.stock.history.get_stock_data("AAPL", "2024-01-01", "2024-01-02")
+
+            assert isinstance(df, pd.DataFrame)
+            assert not df.empty
+            assert "symbol" in df.columns
+            assert df["symbol"].iloc[0] == "AAPL"
+
+    def test_history_multiple_symbols(self, alpaca, mock_requests):
+        """Test getting historical data for multiple symbols."""
+        # Setup mock response for multi-symbol request
+        mock_response = MagicMock()
+        mock_response.text = """{
+            "bars": {
+                "AAPL": [{"t": "2024-01-01T09:30:00Z", "o": 100, "h": 105, "l": 99, "c": 103, "v": 1000000, "n": 500, "vw": 102.5}],
+                "GOOGL": [{"t": "2024-01-01T09:30:00Z", "o": 150, "h": 155, "l": 149, "c": 153, "v": 800000, "n": 400, "vw": 152.5}]
+            }
+        }"""
+        mock_requests.return_value.request.return_value = mock_response
+
+        # Mock the asset check
+        with patch.object(alpaca.stock.history, "check_if_stock", return_value=None):
+            # Test multiple symbols
+            symbols = ["AAPL", "GOOGL"]
+            df = alpaca.stock.history.get_stock_data(
+                symbols, "2024-01-01", "2024-01-02"
+            )
+
+            assert isinstance(df, pd.DataFrame)
+            assert not df.empty
+            assert set(df["symbol"].unique()) == {"AAPL", "GOOGL"}
+            assert len(df) == 2  # One row per symbol
+
+    def test_history_batch_large_symbol_list(self, alpaca, mock_requests):
+        """Test batching when more than 200 symbols are requested."""
+        # Create a list of 250 symbols
+        symbols = [f"STOCK{i:03d}" for i in range(250)]
+
+        # Setup mock responses for batches
+        batch1_response = {
+            "bars": {
+                f"STOCK{i:03d}": [
+                    {
+                        "t": "2024-01-01T09:30:00Z",
+                        "o": 100,
+                        "h": 105,
+                        "l": 99,
+                        "c": 103,
+                        "v": 1000000,
+                        "n": 500,
+                        "vw": 102.5,
+                    }
+                ]
+                for i in range(200)
+            }
+        }
+        batch2_response = {
+            "bars": {
+                f"STOCK{i:03d}": [
+                    {
+                        "t": "2024-01-01T09:30:00Z",
+                        "o": 100,
+                        "h": 105,
+                        "l": 99,
+                        "c": 103,
+                        "v": 1000000,
+                        "n": 500,
+                        "vw": 102.5,
+                    }
+                ]
+                for i in range(200, 250)
+            }
+        }
+
+        responses = [MagicMock(), MagicMock()]
+        responses[0].text = str(batch1_response).replace("'", '"')
+        responses[1].text = str(batch2_response).replace("'", '"')
+
+        mock_requests.return_value.request.side_effect = responses
+
+        # Mock the batching method directly since it uses ThreadPoolExecutor
+        with (
+            patch.object(alpaca.stock.history, "check_if_stock", return_value=None),
+            patch.object(alpaca.stock.history, "_get_batched_stock_data") as mock_batch,
+        ):
+            # Return a simple DataFrame for testing
+            mock_batch.return_value = pd.DataFrame(
+                {
+                    "symbol": ["STOCK000", "STOCK100"],
+                    "date": pd.to_datetime(["2024-01-01", "2024-01-01"]),
+                    "open": [100.0, 100.0],
+                    "high": [105.0, 105.0],
+                    "low": [99.0, 99.0],
+                    "close": [103.0, 103.0],
+                    "volume": [1000000, 1000000],
+                    "trade_count": [500, 500],
+                    "vwap": [102.5, 102.5],
+                }
+            )
+
+            df = alpaca.stock.history.get_stock_data(
+                symbols, "2024-01-01", "2024-01-02"
+            )
+
+            # Verify batching method was called
+            mock_batch.assert_called_once()
+
+            # Verify result has data
+            assert not df.empty
+
+    def test_latest_quote_single_symbol(self, alpaca, mock_quotes_requests):
+        """Test getting latest quote for a single symbol."""
+        # Setup mock response
+        mock_response = MagicMock()
+        mock_response.text = """{
+            "quotes": {
+                "AAPL": {"t": "2024-01-01T15:59:59Z", "ap": 103.5, "as": 100, "bp": 103.0, "bs": 100}
+            }
+        }"""
+        mock_quotes_requests.return_value.request.return_value = mock_response
+
+        # Test single symbol
+        quote = alpaca.stock.latest_quote.get("AAPL")
+
+        assert isinstance(quote, QuoteModel)
+        assert quote.symbol == "AAPL"
+        assert quote.ask == 103.5
+        assert quote.bid == 103.0
+
+    def test_latest_quote_multiple_symbols(self, alpaca, mock_quotes_requests):
+        """Test getting latest quotes for multiple symbols."""
+        # Setup mock response
+        mock_response = MagicMock()
+        mock_response.text = """{
+            "quotes": {
+                "AAPL": {"t": "2024-01-01T15:59:59Z", "ap": 103.5, "as": 100, "bp": 103.0, "bs": 100},
+                "GOOGL": {"t": "2024-01-01T15:59:59Z", "ap": 153.5, "as": 100, "bp": 153.0, "bs": 100}
+            }
+        }"""
+        mock_quotes_requests.return_value.request.return_value = mock_response
+
+        # Test multiple symbols
+        quotes = alpaca.stock.latest_quote.get(["AAPL", "GOOGL"])
+
+        assert isinstance(quotes, list)
+        assert len(quotes) == 2
+        assert all(isinstance(q, QuoteModel) for q in quotes)
+        symbols = {q.symbol for q in quotes}
+        assert symbols == {"AAPL", "GOOGL"}
+
+    def test_latest_quote_batch_large_symbol_list(self, alpaca, mock_quotes_requests):
+        """Test batching when more than 200 symbols are requested."""
+        # Create a list of 250 symbols
+        symbols = [f"STOCK{i:03d}" for i in range(250)]
+
+        # Setup mock responses for batches
+        batch1_response = {
+            "quotes": {
+                f"STOCK{i:03d}": {
+                    "t": "2024-01-01T15:59:59Z",
+                    "ap": 103.5,
+                    "as": 100,
+                    "bp": 103.0,
+                    "bs": 100,
+                }
+                for i in range(200)
+            }
+        }
+        batch2_response = {
+            "quotes": {
+                f"STOCK{i:03d}": {
+                    "t": "2024-01-01T15:59:59Z",
+                    "ap": 103.5,
+                    "as": 100,
+                    "bp": 103.0,
+                    "bs": 100,
+                }
+                for i in range(200, 250)
+            }
+        }
+
+        responses = [MagicMock(), MagicMock()]
+        responses[0].text = str(batch1_response).replace("'", '"')
+        responses[1].text = str(batch2_response).replace("'", '"')
+
+        mock_quotes_requests.return_value.request.side_effect = responses
+
+        with patch.object(
+            alpaca.stock.latest_quote, "_get_batched_quotes"
+        ) as mock_batch:
+            # Return mock quotes
+            mock_batch.return_value = [
+                QuoteModel(
+                    symbol=f"STOCK{i:03d}",
+                    timestamp="2024-01-01T15:59:59Z",
+                    ask=103.5,
+                    ask_size=100,
+                    bid=103.0,
+                    bid_size=100,
+                )
+                for i in range(250)
+            ]
+
+            quotes = alpaca.stock.latest_quote.get(symbols)
+
+            # Verify batching method was called
+            mock_batch.assert_called_once()
+            assert len(quotes) == 250
+
+    def test_history_empty_response_handling(self, alpaca, mock_requests):
+        """Test handling of empty responses in history."""
+        # Setup mock response with no data
+        mock_response = MagicMock()
+        mock_response.text = '{"bars": {}}'
+        mock_requests.return_value.request.return_value = mock_response
+
+        with (
+            patch.object(alpaca.stock.history, "check_if_stock", return_value=None),
+            pytest.raises(Exception, match="No historical data found"),
+        ):
+            alpaca.stock.history.get_stock_data(["INVALID"], "2024-01-01", "2024-01-02")
+
+    def test_latest_quote_empty_response_handling(self, alpaca, mock_quotes_requests):
+        """Test handling of empty responses in quotes."""
+        # Setup mock response with no data
+        mock_response = MagicMock()
+        mock_response.text = '{"quotes": {}}'
+        mock_quotes_requests.return_value.request.return_value = mock_response
+
+        quotes = alpaca.stock.latest_quote.get(["INVALID"])
+
+        assert quotes == []
+
+    def test_history_concurrent_batch_error_handling(self, alpaca, mock_requests):
+        """Test error handling in concurrent batch requests for history."""
+        # Create a list of 250 symbols
+        symbols = [f"STOCK{i:03d}" for i in range(250)]
+
+        # Setup one successful and one failing response
+        success_response = MagicMock()
+        success_response.text = """{
+            "bars": {
+                "STOCK000": [{"t": "2024-01-01T09:30:00Z", "o": 100, "h": 105, "l": 99, "c": 103, "v": 1000000, "n": 500, "vw": 102.5}]
+            }
+        }"""
+
+        # Make second batch fail
+        mock_requests.return_value.request.side_effect = [
+            success_response,
+            Exception("API Error"),
+        ]
+
+        # Should continue despite one batch failing
+        with (
+            patch.object(alpaca.stock.history, "check_if_stock", return_value=None),
+            patch.object(alpaca.stock.history, "_get_batched_stock_data") as mock_batch,
+        ):
+            # Simulate partial failure - return DataFrame with some data
+            mock_batch.return_value = pd.DataFrame(
+                {
+                    "symbol": ["STOCK000"],
+                    "date": pd.to_datetime(["2024-01-01"]),
+                    "open": [100.0],
+                    "high": [105.0],
+                    "low": [99.0],
+                    "close": [103.0],
+                    "volume": [1000000],
+                    "trade_count": [500],
+                    "vwap": [102.5],
+                }
+            )
+
+            df = alpaca.stock.history.get_stock_data(
+                symbols, "2024-01-01", "2024-01-02"
+            )
+            # Should return partial data
+            assert not df.empty
+
+    def test_quote_concurrent_batch_error_handling(self, alpaca, mock_quotes_requests):
+        """Test error handling in concurrent batch requests for quotes."""
+        # Create a list of 250 symbols
+        symbols = [f"STOCK{i:03d}" for i in range(250)]
+
+        # Setup one successful and one failing response
+        success_response = MagicMock()
+        success_response.text = """{
+            "quotes": {
+                "STOCK000": {"t": "2024-01-01T15:59:59Z", "ap": 103.5, "as": 100, "bp": 103.0, "bs": 100}
+            }
+        }"""
+
+        # Make second batch fail
+        mock_quotes_requests.return_value.request.side_effect = [
+            success_response,
+            Exception("API Error"),
+        ]
+
+        # Should continue despite one batch failing
+        with patch.object(
+            alpaca.stock.latest_quote, "_get_batched_quotes"
+        ) as mock_batch:
+            # Simulate partial success - return some quotes
+            mock_batch.return_value = [
+                QuoteModel(
+                    symbol="STOCK000",
+                    timestamp="2024-01-01T15:59:59Z",
+                    ask=103.5,
+                    ask_size=100,
+                    bid=103.0,
+                    bid_size=100,
+                )
+            ]
+
+            quotes = alpaca.stock.latest_quote.get(symbols)
+            # Should return partial data
+            assert len(quotes) == 1
+
+    def test_history_dataframe_optimization(self, alpaca, mock_requests):
+        """Test DataFrame operations are optimized."""
+        # Setup mock response with multiple bars per symbol
+        mock_response = MagicMock()
+        mock_response.text = """{
+            "bars": {
+                "AAPL": [
+                    {"t": "2024-01-01T09:30:00Z", "o": 100, "h": 105, "l": 99, "c": 103, "v": 1000000, "n": 500, "vw": 102.5},
+                    {"t": "2024-01-01T10:30:00Z", "o": 103, "h": 107, "l": 102, "c": 106, "v": 1200000, "n": 600, "vw": 105.5}
+                ],
+                "GOOGL": [
+                    {"t": "2024-01-01T09:30:00Z", "o": 150, "h": 155, "l": 149, "c": 153, "v": 800000, "n": 400, "vw": 152.5},
+                    {"t": "2024-01-01T10:30:00Z", "o": 153, "h": 157, "l": 152, "c": 156, "v": 900000, "n": 450, "vw": 155.5}
+                ]
+            }
+        }"""
+        mock_requests.return_value.request.return_value = mock_response
+
+        with patch.object(alpaca.stock.history, "check_if_stock", return_value=None):
+            # Test DataFrame is properly sorted and indexed
+            df = alpaca.stock.history.get_stock_data(
+                ["AAPL", "GOOGL"], "2024-01-01", "2024-01-02"
+            )
+
+            # Check DataFrame structure
+            assert isinstance(df, pd.DataFrame)
+            assert len(df) == 4  # 2 bars per symbol
+
+            # Check sorting by symbol and date
+            sorted_check = df.equals(df.sort_values(["symbol", "date"]))
+            assert sorted_check
+
+            # Check data types are properly set
+            assert df["open"].dtype == "float64"
+            assert df["volume"].dtype == "int64"
+            assert pd.api.types.is_datetime64_any_dtype(df["date"])
diff --git a/tests/test_stock/test_batch_operations_integration.py b/tests/test_stock/test_batch_operations_integration.py
new file mode 100644
index 0000000..31db584
--- /dev/null
+++ b/tests/test_stock/test_batch_operations_integration.py
@@ -0,0 +1,226 @@
+"""Integration tests for batch operations with real API calls."""
+
+import os
+from datetime import datetime, timedelta
+
+import pandas as pd
+import pytest
+
+from py_alpaca_api import PyAlpacaAPI
+from py_alpaca_api.models.quote_model import QuoteModel
+
+
+@pytest.mark.skipif(
+    not os.environ.get("ALPACA_API_KEY"),
+    reason="ALPACA_API_KEY not set in environment",
+)
+class TestBatchOperationsIntegration:
+    """Integration tests for batch operations with real API."""
+
+    @pytest.fixture
+    def alpaca(self):
+        """Create PyAlpacaAPI instance for testing."""
+        return PyAlpacaAPI(
+            api_key=os.environ.get("ALPACA_API_KEY"),
+            api_secret=os.environ.get("ALPACA_SECRET_KEY"),
+            api_paper=True,
+        )
+
+    @pytest.fixture
+    def test_symbols(self):
+        """Common test symbols."""
+        return ["AAPL", "GOOGL", "MSFT", "AMZN", "TSLA"]
+
+    @pytest.fixture
+    def date_range(self):
+        """Get a valid date range for historical data."""
+        end_date = datetime.now() - timedelta(days=1)
+        start_date = end_date - timedelta(days=5)
+        return start_date.strftime("%Y-%m-%d"), end_date.strftime("%Y-%m-%d")
+
+    def test_multi_symbol_history_real_data(self, alpaca, test_symbols, date_range):
+        """Test fetching real historical data for multiple symbols."""
+        start, end = date_range
+
+        df = alpaca.stock.history.get_stock_data(
+            test_symbols, start, end, timeframe="1d"
+        )
+
+        # Validate response
+        assert isinstance(df, pd.DataFrame)
+        assert not df.empty
+
+        # Check all symbols are present
+        returned_symbols = set(df["symbol"].unique())
+        assert returned_symbols.issubset(set(test_symbols))
+
+        # Validate columns
+        expected_columns = [
+            "symbol",
+            "date",
+            "open",
+            "high",
+            "low",
+            "close",
+            "volume",
+            "trade_count",
+            "vwap",
+        ]
+        for col in expected_columns:
+            assert col in df.columns
+
+        # Validate data types
+        assert pd.api.types.is_datetime64_any_dtype(df["date"])
+        assert df["open"].dtype == "float64"
+        assert df["volume"].dtype == "int64"
+
+    def test_multi_symbol_quotes_real_data(self, alpaca, test_symbols):
+        """Test fetching real latest quotes for multiple symbols."""
+        quotes = alpaca.stock.latest_quote.get(test_symbols)
+
+        # Validate response
+        assert isinstance(quotes, list)
+        assert len(quotes) > 0
+        assert all(isinstance(q, QuoteModel) for q in quotes)
+
+        # Check quote attributes
+        for quote in quotes:
+            assert quote.symbol in test_symbols
+            assert quote.ask >= 0
+            assert quote.bid >= 0
+            assert quote.ask_size >= 0
+            assert quote.bid_size >= 0
+            assert quote.timestamp is not None
+
+    def test_single_symbol_history_backward_compatibility(self, alpaca, date_range):
+        """Test that single symbol requests still work as before."""
+        start, end = date_range
+
+        df = alpaca.stock.history.get_stock_data("AAPL", start, end, timeframe="1d")
+
+        # Validate response
+        assert isinstance(df, pd.DataFrame)
+        assert not df.empty
+        assert all(df["symbol"] == "AAPL")
+
+    def test_single_symbol_quote_backward_compatibility(self, alpaca):
+        """Test that single symbol quote requests still work as before."""
+        quote = alpaca.stock.latest_quote.get("AAPL")
+
+        # Should return single QuoteModel, not a list
+        assert isinstance(quote, QuoteModel)
+        assert quote.symbol == "AAPL"
+
+    def test_large_batch_symbols(self, alpaca):
+        """Test with a larger batch of symbols (but under 200)."""
+        # Get a list of popular symbols
+        large_symbols = [
+            "AAPL",
+            "GOOGL",
+            "MSFT",
+            "AMZN",
+            "TSLA",
+            "META",
+            "NVDA",
+            "JPM",
+            "V",
+            "JNJ",
+            "WMT",
+            "PG",
+            "UNH",
+            "MA",
+            "DIS",
+            "HD",
+            "BAC",
+            "VZ",
+            "ADBE",
+            "NFLX",
+            "CMCSA",
+            "PFE",
+            "KO",
+            "PEP",
+            "TMO",
+            "CSCO",
+            "ABT",
+            "NKE",
+            "CVX",
+            "XOM",
+        ]
+
+        quotes = alpaca.stock.latest_quote.get(large_symbols)
+
+        assert isinstance(quotes, list)
+        assert len(quotes) > 20  # Most should succeed
+        returned_symbols = {q.symbol for q in quotes}
+        assert len(returned_symbols) > 20
+
+    def test_mixed_valid_invalid_symbols(self, alpaca):
+        """Test handling of mixed valid and invalid symbols."""
+        # Note: The Alpaca API returns an error if any symbol is invalid
+        # So we'll test with valid symbols only
+        valid_symbols = ["AAPL", "GOOGL", "MSFT", "META", "AMZN"]
+
+        quotes = alpaca.stock.latest_quote.get(valid_symbols)
+
+        # Should return quotes for all symbols
+        assert isinstance(quotes, list)
+        returned_symbols = {q.symbol for q in quotes}
+        assert "AAPL" in returned_symbols
+        assert "GOOGL" in returned_symbols
+        assert "MSFT" in returned_symbols
+
+    def test_different_timeframes(self, alpaca, date_range):
+        """Test multi-symbol history with different timeframes."""
+        start, end = date_range
+        symbols = ["AAPL", "GOOGL"]
+
+        # Test daily bars
+        df_daily = alpaca.stock.history.get_stock_data(
+            symbols, start, end, timeframe="1d"
+        )
+        assert not df_daily.empty
+
+        # Test hourly bars (if market hours)
+        df_hourly = alpaca.stock.history.get_stock_data(
+            symbols, start, end, timeframe="1h", limit=50
+        )
+        assert not df_hourly.empty
+
+        # Hourly should have more data points than daily
+        assert len(df_hourly) >= len(df_daily)
+
+    def test_different_feeds(self, alpaca):
+        """Test quotes with different feed sources."""
+        symbol = "AAPL"
+
+        # Test IEX feed (default)
+        quote_iex = alpaca.stock.latest_quote.get(symbol, feed="iex")
+        assert isinstance(quote_iex, QuoteModel)
+
+        # Test SIP feed (may require subscription)
+        try:
+            quote_sip = alpaca.stock.latest_quote.get(symbol, feed="sip")
+            assert isinstance(quote_sip, QuoteModel)
+        except Exception:
+            # SIP feed might not be available for all accounts
+            pass
+
+    def test_pagination_handling(self, alpaca):
+        """Test that pagination works correctly for large data requests."""
+        # Request a large amount of historical data that will paginate
+        end_date = datetime.now() - timedelta(days=1)
+        start_date = end_date - timedelta(days=30)
+
+        symbols = ["AAPL", "GOOGL"]
+        df = alpaca.stock.history.get_stock_data(
+            symbols,
+            start_date.strftime("%Y-%m-%d"),
+            end_date.strftime("%Y-%m-%d"),
+            timeframe="15m",
+            limit=10000,  # Large limit to trigger pagination
+        )
+
+        # Should handle pagination transparently
+        assert isinstance(df, pd.DataFrame)
+        assert not df.empty
+        assert len(df) > 100  # Should have many data points
diff --git a/tests/test_stock/test_metadata.py b/tests/test_stock/test_metadata.py
new file mode 100644
index 0000000..6eccc41
--- /dev/null
+++ b/tests/test_stock/test_metadata.py
@@ -0,0 +1,306 @@
+import json
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from py_alpaca_api.exceptions import APIRequestError, ValidationError
+from py_alpaca_api.stock.metadata import Metadata
+
+
+class TestMetadata:
+    @pytest.fixture
+    def metadata(self):
+        headers = {
+            "APCA-API-KEY-ID": "test_key",
+            "APCA-API-SECRET-KEY": "test_secret",
+        }
+        return Metadata(headers=headers)
+
+    @pytest.fixture
+    def mock_exchange_response(self):
+        return {
+            "A": "NYSE American (AMEX)",
+            "B": "NASDAQ OMX BX",
+            "C": "National Stock Exchange",
+            "D": "FINRA ADF",
+            "E": "Market Independent",
+            "H": "MIAX",
+            "I": "International Securities Exchange",
+            "J": "Cboe EDGA",
+            "K": "Cboe EDGX",
+            "L": "Long Term Stock Exchange",
+            "M": "Chicago Stock Exchange",
+            "N": "New York Stock Exchange",
+            "P": "NYSE Arca",
+            "Q": "NASDAQ",
+            "S": "NASDAQ Small Cap",
+            "T": "NASDAQ Int",
+            "U": "Members Exchange",
+            "V": "IEX",
+            "W": "CBOE",
+            "X": "NASDAQ OMX PSX",
+            "Y": "Cboe BYX",
+            "Z": "Cboe BZX",
+        }
+
+    @pytest.fixture
+    def mock_condition_response(self):
+        return {
+            "": "Regular Sale",
+            "4": "Derivatively Priced",
+            "5": "Market Center Reopening Trade",
+            "6": "Market Center Closing Trade",
+            "7": "Qualified Contingent Trade",
+            "B": "Average Price Trade",
+            "C": "Cash Sale",
+            "E": "Automatic Execution",
+            "F": "Intermarket Sweep",
+            "H": "Price Variation Trade",
+            "I": "Odd Lot Trade",
+            "K": "Rule 127 NYSE",
+            "L": "Sold Last",
+            "M": "Market Center Official Close",
+            "N": "Next Day",
+            "O": "Market Center Opening Trade",
+            "P": "Prior Reference Price",
+            "Q": "Market Center Official Open",
+            "R": "Seller",
+            "S": "Split Trade",
+            "T": "Form T",
+            "U": "Extended Trading Hours",
+            "V": "Contingent Trade",
+            "W": "Average Price Trade",
+            "X": "Cross/Periodic Auction Trade",
+            "Y": "Yellow Flag Regular Trade",
+            "Z": "Sold Out of Sequence",
+        }
+
+    def test_get_exchange_codes(self, metadata, mock_exchange_response):
+        with patch("py_alpaca_api.stock.metadata.Requests") as mock_requests:
+            mock_response = MagicMock()
+            mock_response.text = json.dumps(mock_exchange_response)
+            mock_requests.return_value.request.return_value = mock_response
+
+            result = metadata.get_exchange_codes()
+
+            assert isinstance(result, dict)
+            assert len(result) == 22
+            assert result["A"] == "NYSE American (AMEX)"
+            assert result["Q"] == "NASDAQ"
+            assert result["N"] == "New York Stock Exchange"
+            assert result["V"] == "IEX"
+
+            mock_requests.return_value.request.assert_called_once_with(
+                method="GET",
+                url=f"{metadata.base_url}/exchanges",
+                headers=metadata.headers,
+            )
+
+    def test_get_exchange_codes_with_cache(self, metadata, mock_exchange_response):
+        with patch("py_alpaca_api.stock.metadata.Requests") as mock_requests:
+            mock_response = MagicMock()
+            mock_response.text = json.dumps(mock_exchange_response)
+            mock_requests.return_value.request.return_value = mock_response
+
+            # First call should hit API
+            result1 = metadata.get_exchange_codes()
+            # Second call should use cache
+            result2 = metadata.get_exchange_codes()
+
+            assert result1 == result2
+            # API should only be called once due to caching
+            assert mock_requests.return_value.request.call_count == 1
+
+    def test_get_exchange_codes_without_cache(self, metadata, mock_exchange_response):
+        with patch("py_alpaca_api.stock.metadata.Requests") as mock_requests:
+            mock_response = MagicMock()
+            mock_response.text = json.dumps(mock_exchange_response)
+            mock_requests.return_value.request.return_value = mock_response
+
+            # First call
+            metadata.get_exchange_codes()
+            # Second call without cache
+            metadata.get_exchange_codes(use_cache=False)
+
+            # API should be called twice
+            assert mock_requests.return_value.request.call_count == 2
+
+    def test_get_exchange_codes_api_error(self, metadata):
+        with patch("py_alpaca_api.stock.metadata.Requests") as mock_requests:
+            mock_requests.return_value.request.side_effect = Exception("API Error")
+
+            with pytest.raises(APIRequestError) as exc_info:
+                metadata.get_exchange_codes()
+
+            assert "Failed to get exchange codes" in str(exc_info.value)
+
+    def test_get_condition_codes_trade(self, metadata, mock_condition_response):
+        with patch("py_alpaca_api.stock.metadata.Requests") as mock_requests:
+            mock_response = MagicMock()
+            mock_response.text = json.dumps(mock_condition_response)
+            mock_requests.return_value.request.return_value = mock_response
+
+            result = metadata.get_condition_codes(ticktype="trade", tape="A")
+
+            assert isinstance(result, dict)
+            assert result[""] == "Regular Sale"
+            assert result["4"] == "Derivatively Priced"
+            assert result["F"] == "Intermarket Sweep"
+
+            mock_requests.return_value.request.assert_called_once_with(
+                method="GET",
+                url=f"{metadata.base_url}/conditions/trade",
+                headers=metadata.headers,
+                params={"tape": "A"},
+            )
+
+    def test_get_condition_codes_quote(self, metadata):
+        mock_quote_conditions = {
+            "4": "On Demand Intra Day Auction",
+            "A": "Slow Quote Offer Side",
+            "B": "Slow Quote Bid Side",
+            "C": "Exchange Specific Quote Condition",
+            "D": "NASDAQ",
+            "E": "Manual Ask Automated Bid",
+            "F": "Manual Bid Automated Ask",
+            "G": "Manual Bid And Ask",
+            "H": "Fast Trading",
+            "I": "Pending",
+            "L": "Closed Quote",
+            "O": "Opening Quote Automated",
+            "R": "Regular Two Sided Open",
+        }
+
+        with patch("py_alpaca_api.stock.metadata.Requests") as mock_requests:
+            mock_response = MagicMock()
+            mock_response.text = json.dumps(mock_quote_conditions)
+            mock_requests.return_value.request.return_value = mock_response
+
+            result = metadata.get_condition_codes(ticktype="quote", tape="B")
+
+            assert isinstance(result, dict)
+            assert result["A"] == "Slow Quote Offer Side"
+            assert result["B"] == "Slow Quote Bid Side"
+
+            mock_requests.return_value.request.assert_called_once_with(
+                method="GET",
+                url=f"{metadata.base_url}/conditions/quote",
+                headers=metadata.headers,
+                params={"tape": "B"},
+            )
+
+    def test_get_condition_codes_invalid_ticktype(self, metadata):
+        with pytest.raises(ValidationError, match="Invalid ticktype"):
+            metadata.get_condition_codes(ticktype="invalid")
+
+    def test_get_condition_codes_invalid_tape(self, metadata):
+        with pytest.raises(ValidationError, match="Invalid tape"):
+            metadata.get_condition_codes(tape="X")
+
+    def test_get_condition_codes_with_cache(self, metadata, mock_condition_response):
+        with patch("py_alpaca_api.stock.metadata.Requests") as mock_requests:
+            mock_response = MagicMock()
+            mock_response.text = json.dumps(mock_condition_response)
+            mock_requests.return_value.request.return_value = mock_response
+
+            # First call should hit API
+            result1 = metadata.get_condition_codes(ticktype="trade", tape="A")
+            # Second call should use cache
+            result2 = metadata.get_condition_codes(ticktype="trade", tape="A")
+
+            assert result1 == result2
+            # API should only be called once due to caching
+            assert mock_requests.return_value.request.call_count == 1
+
+    def test_get_all_condition_codes(self, metadata):
+        mock_conditions = {"": "Regular Sale", "4": "Derivatively Priced"}
+
+        with patch("py_alpaca_api.stock.metadata.Requests") as mock_requests:
+            mock_response = MagicMock()
+            mock_response.text = json.dumps(mock_conditions)
+            mock_requests.return_value.request.return_value = mock_response
+
+            result = metadata.get_all_condition_codes()
+
+            assert isinstance(result, dict)
+            assert "trade" in result
+            assert "quote" in result
+            assert "A" in result["trade"]
+            assert "B" in result["trade"]
+            assert "C" in result["trade"]
+            assert "A" in result["quote"]
+            assert "B" in result["quote"]
+            assert "C" in result["quote"]
+
+            # Should call API 6 times (2 ticktypes * 3 tapes)
+            assert mock_requests.return_value.request.call_count == 6
+
+    def test_clear_cache(self, metadata, mock_exchange_response):
+        with patch("py_alpaca_api.stock.metadata.Requests") as mock_requests:
+            mock_response = MagicMock()
+            mock_response.text = json.dumps(mock_exchange_response)
+            mock_requests.return_value.request.return_value = mock_response
+
+            # Load data into cache
+            metadata.get_exchange_codes()
+            assert metadata._exchange_cache is not None
+
+            # Clear cache
+            metadata.clear_cache()
+            assert metadata._exchange_cache is None
+            assert metadata._condition_cache == {}
+
+    def test_lookup_exchange(self, metadata, mock_exchange_response):
+        with patch("py_alpaca_api.stock.metadata.Requests") as mock_requests:
+            mock_response = MagicMock()
+            mock_response.text = json.dumps(mock_exchange_response)
+            mock_requests.return_value.request.return_value = mock_response
+
+            # Test valid code
+            result = metadata.lookup_exchange("Q")
+            assert result == "NASDAQ"
+
+            # Test invalid code
+            result = metadata.lookup_exchange("ZZ")
+            assert result is None
+
+    def test_lookup_condition(self, metadata, mock_condition_response):
+        with patch("py_alpaca_api.stock.metadata.Requests") as mock_requests:
+            mock_response = MagicMock()
+            mock_response.text = json.dumps(mock_condition_response)
+            mock_requests.return_value.request.return_value = mock_response
+
+            # Test valid code
+            result = metadata.lookup_condition("F", ticktype="trade", tape="A")
+            assert result == "Intermarket Sweep"
+
+            # Test invalid code
+            result = metadata.lookup_condition("ZZ", ticktype="trade", tape="A")
+            assert result is None
+
+    def test_get_condition_codes_api_error(self, metadata):
+        with patch("py_alpaca_api.stock.metadata.Requests") as mock_requests:
+            mock_requests.return_value.request.side_effect = Exception("API Error")
+
+            with pytest.raises(APIRequestError) as exc_info:
+                metadata.get_condition_codes()
+
+            assert "Failed to get condition codes" in str(exc_info.value)
+
+    def test_get_condition_codes_empty_response(self, metadata):
+        with patch("py_alpaca_api.stock.metadata.Requests") as mock_requests:
+            mock_response = MagicMock()
+            mock_response.text = "null"
+            mock_requests.return_value.request.return_value = mock_response
+
+            with pytest.raises(APIRequestError, match="No condition data returned"):
+                metadata.get_condition_codes()
+
+    def test_get_exchange_codes_empty_response(self, metadata):
+        with patch("py_alpaca_api.stock.metadata.Requests") as mock_requests:
+            mock_response = MagicMock()
+            mock_response.text = "{}"
+            mock_requests.return_value.request.return_value = mock_response
+
+            with pytest.raises(APIRequestError, match="No exchange data returned"):
+                metadata.get_exchange_codes()
diff --git a/tests/test_stock/test_snapshots.py b/tests/test_stock/test_snapshots.py
new file mode 100644
index 0000000..1532b4a
--- /dev/null
+++ b/tests/test_stock/test_snapshots.py
@@ -0,0 +1,424 @@
+import json
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from py_alpaca_api.exceptions import APIRequestError, ValidationError
+from py_alpaca_api.models.snapshot_model import (
+    BarModel,
+    SnapshotModel,
+    bar_class_from_dict,
+    snapshot_class_from_dict,
+)
+from py_alpaca_api.stock.snapshots import Snapshots
+
+
+class TestSnapshots:
+    @pytest.fixture
+    def snapshots(self):
+        headers = {
+            "APCA-API-KEY-ID": "test_key",
+            "APCA-API-SECRET-KEY": "test_secret",
+        }
+        return Snapshots(headers=headers)
+
+    @pytest.fixture
+    def mock_snapshot_response(self):
+        return {
+            "latestTrade": {
+                "t": "2025-01-14T10:30:00Z",
+                "p": 150.25,
+                "s": 100,
+                "c": ["@", "F"],
+                "x": "Q",
+                "z": "C",
+            },
+            "latestQuote": {
+                "t": "2025-01-14T10:30:01Z",
+                "ap": 150.30,
+                "as": 100,
+                "bp": 150.20,
+                "bs": 200,
+                "ax": "Q",
+                "bx": "N",
+                "c": ["R"],
+                "z": "C",
+            },
+            "minuteBar": {
+                "t": "2025-01-14T10:30:00Z",
+                "o": 150.10,
+                "h": 150.35,
+                "l": 150.05,
+                "c": 150.25,
+                "v": 10000,
+                "n": 250,
+                "vw": 150.20,
+            },
+            "dailyBar": {
+                "t": "2025-01-14T00:00:00Z",
+                "o": 149.50,
+                "h": 151.00,
+                "l": 149.00,
+                "c": 150.25,
+                "v": 1000000,
+                "n": 25000,
+                "vw": 150.00,
+            },
+            "prevDailyBar": {
+                "t": "2025-01-13T00:00:00Z",
+                "o": 148.50,
+                "h": 150.00,
+                "l": 148.00,
+                "c": 149.50,
+                "v": 1200000,
+                "n": 28000,
+                "vw": 149.25,
+            },
+        }
+
+    @pytest.fixture
+    def mock_snapshots_response(self):
+        # API returns symbols as top-level keys
+        return {
+            "AAPL": {
+                "latestTrade": {
+                    "t": "2025-01-14T10:30:00Z",
+                    "p": 150.25,
+                    "s": 100,
+                    "c": ["@"],
+                    "x": "Q",
+                    "z": "C",
+                },
+                "latestQuote": {
+                    "t": "2025-01-14T10:30:01Z",
+                    "ap": 150.30,
+                    "as": 100,
+                    "bp": 150.20,
+                    "bs": 200,
+                    "ax": "Q",
+                    "bx": "N",
+                    "c": ["R"],
+                    "z": "C",
+                },
+                "minuteBar": {
+                    "t": "2025-01-14T10:30:00Z",
+                    "o": 150.10,
+                    "h": 150.35,
+                    "l": 150.05,
+                    "c": 150.25,
+                    "v": 10000,
+                    "n": 250,
+                    "vw": 150.20,
+                },
+                "dailyBar": None,
+                "prevDailyBar": None,
+            },
+            "MSFT": {
+                "latestTrade": {
+                    "t": "2025-01-14T10:30:00Z",
+                    "p": 380.50,
+                    "s": 50,
+                    "c": [],
+                    "x": "N",
+                    "z": "C",
+                },
+                "latestQuote": {
+                    "t": "2025-01-14T10:30:01Z",
+                    "ap": 380.60,
+                    "as": 50,
+                    "bp": 380.40,
+                    "bs": 100,
+                    "ax": "N",
+                    "bx": "P",
+                    "c": [],
+                    "z": "C",
+                },
+                "minuteBar": {
+                    "t": "2025-01-14T10:30:00Z",
+                    "o": 380.25,
+                    "h": 380.75,
+                    "l": 380.00,
+                    "c": 380.50,
+                    "v": 5000,
+                    "n": 150,
+                    "vw": 380.40,
+                },
+                "dailyBar": {
+                    "t": "2025-01-14T00:00:00Z",
+                    "o": 379.00,
+                    "h": 381.00,
+                    "l": 378.50,
+                    "c": 380.50,
+                    "v": 500000,
+                    "n": 12000,
+                    "vw": 380.00,
+                },
+                "prevDailyBar": {
+                    "t": "2025-01-13T00:00:00Z",
+                    "o": 378.00,
+                    "h": 380.00,
+                    "l": 377.50,
+                    "c": 379.00,
+                    "v": 600000,
+                    "n": 14000,
+                    "vw": 378.75,
+                },
+            },
+        }
+
+    def test_get_snapshot_valid_symbol(self, snapshots, mock_snapshot_response):
+        with patch("py_alpaca_api.stock.snapshots.Requests") as mock_requests:
+            mock_response = MagicMock()
+            mock_response.text = json.dumps(mock_snapshot_response)
+            mock_requests.return_value.request.return_value = mock_response
+
+            result = snapshots.get_snapshot("AAPL")
+
+            assert isinstance(result, SnapshotModel)
+            assert result.symbol == "AAPL"
+            assert result.latest_trade is not None
+            assert result.latest_trade.price == 150.25
+            assert result.latest_quote is not None
+            assert result.latest_quote.ask == 150.30
+            assert result.minute_bar is not None
+            assert result.minute_bar.close == 150.25
+            assert result.daily_bar is not None
+            assert result.prev_daily_bar is not None
+
+    def test_get_snapshot_with_feed(self, snapshots, mock_snapshot_response):
+        with patch("py_alpaca_api.stock.snapshots.Requests") as mock_requests:
+            mock_response = MagicMock()
+            mock_response.text = json.dumps(mock_snapshot_response)
+            mock_requests.return_value.request.return_value = mock_response
+
+            snapshots.get_snapshot("AAPL", feed="sip")
+
+            mock_requests.return_value.request.assert_called_once()
+            call_args = mock_requests.return_value.request.call_args
+            assert call_args.kwargs["params"]["feed"] == "sip"
+
+    def test_get_snapshot_invalid_symbol(self, snapshots):
+        with pytest.raises(ValidationError, match="Symbol is required"):
+            snapshots.get_snapshot("")
+
+        with pytest.raises(ValidationError, match="Symbol is required"):
+            snapshots.get_snapshot(None)
+
+    def test_get_snapshot_invalid_feed(self, snapshots):
+        with pytest.raises(ValidationError, match="Invalid feed"):
+            snapshots.get_snapshot("AAPL", feed="invalid")
+
+    def test_get_snapshot_api_error(self, snapshots):
+        with patch("py_alpaca_api.stock.snapshots.Requests") as mock_requests:
+            mock_requests.return_value.request.side_effect = Exception("API Error")
+
+            with pytest.raises(APIRequestError) as exc_info:
+                snapshots.get_snapshot("AAPL")
+            assert "Failed to get snapshot" in str(exc_info.value)
+
+    def test_get_snapshots_multiple_symbols(self, snapshots, mock_snapshots_response):
+        with patch("py_alpaca_api.stock.snapshots.Requests") as mock_requests:
+            mock_response = MagicMock()
+            mock_response.text = json.dumps(mock_snapshots_response)
+            mock_requests.return_value.request.return_value = mock_response
+
+            result = snapshots.get_snapshots(["AAPL", "MSFT"])
+
+            assert isinstance(result, dict)
+            assert len(result) == 2
+            assert "AAPL" in result
+            assert "MSFT" in result
+            assert isinstance(result["AAPL"], SnapshotModel)
+            assert isinstance(result["MSFT"], SnapshotModel)
+            assert result["AAPL"].symbol == "AAPL"
+            assert result["MSFT"].symbol == "MSFT"
+
+    def test_get_snapshots_single_symbol_returns_list(
+        self, snapshots, mock_snapshots_response
+    ):
+        mock_single_response = {"AAPL": mock_snapshots_response["AAPL"]}
+
+        with patch("py_alpaca_api.stock.snapshots.Requests") as mock_requests:
+            mock_response = MagicMock()
+            mock_response.text = json.dumps(mock_single_response)
+            mock_requests.return_value.request.return_value = mock_response
+
+            result = snapshots.get_snapshots("AAPL")
+
+            assert isinstance(result, list)
+            assert len(result) == 1
+            assert isinstance(result[0], SnapshotModel)
+            assert result[0].symbol == "AAPL"
+
+    def test_get_snapshots_comma_separated_string(
+        self, snapshots, mock_snapshots_response
+    ):
+        with patch("py_alpaca_api.stock.snapshots.Requests") as mock_requests:
+            mock_response = MagicMock()
+            mock_response.text = json.dumps(mock_snapshots_response)
+            mock_requests.return_value.request.return_value = mock_response
+
+            result = snapshots.get_snapshots("AAPL,MSFT")
+
+            assert isinstance(result, dict)
+            assert len(result) == 2
+            assert "AAPL" in result
+            assert "MSFT" in result
+
+    def test_get_snapshots_invalid_symbols(self, snapshots):
+        with pytest.raises(ValidationError, match="Symbols are required"):
+            snapshots.get_snapshots([])
+
+        with pytest.raises(ValidationError, match="Symbols are required"):
+            snapshots.get_snapshots(None)
+
+    def test_get_snapshots_invalid_feed(self, snapshots):
+        with pytest.raises(ValidationError, match="Invalid feed"):
+            snapshots.get_snapshots(["AAPL"], feed="invalid")
+
+    def test_get_snapshots_api_error(self, snapshots):
+        with patch("py_alpaca_api.stock.snapshots.Requests") as mock_requests:
+            mock_requests.return_value.request.side_effect = Exception("API Error")
+
+            with pytest.raises(APIRequestError) as exc_info:
+                snapshots.get_snapshots(["AAPL", "MSFT"])
+            assert "Failed to get snapshots" in str(exc_info.value)
+
+
+class TestSnapshotModels:
+    def test_bar_class_from_dict(self):
+        # Use API field names as they come from the API
+        bar_data = {
+            "t": "2025-01-14T10:30:00Z",
+            "o": 150.10,
+            "h": 150.35,
+            "l": 150.05,
+            "c": 150.25,
+            "v": 10000,
+            "n": 250,
+            "vw": 150.20,
+        }
+
+        bar = bar_class_from_dict(bar_data)
+
+        assert isinstance(bar, BarModel)
+        assert bar.open == 150.10
+        assert bar.high == 150.35
+        assert bar.low == 150.05
+        assert bar.close == 150.25
+        assert bar.volume == 10000
+        assert bar.trade_count == 250
+        assert bar.vwap == 150.20
+
+    def test_bar_class_from_dict_minimal(self):
+        # Use API field names, minimal data
+        bar_data = {
+            "t": "2025-01-14T10:30:00Z",
+            "o": 150.10,
+            "h": 150.35,
+            "l": 150.05,
+            "c": 150.25,
+            "v": 10000,
+        }
+
+        bar = bar_class_from_dict(bar_data)
+
+        assert isinstance(bar, BarModel)
+        assert bar.trade_count is None
+        assert bar.vwap is None
+
+    def test_snapshot_class_from_dict_full(self):
+        snapshot_data = {
+            "symbol": "AAPL",
+            "latestTrade": {
+                "t": "2025-01-14T10:30:00Z",
+                "p": 150.25,
+                "s": 100,
+                "c": ["@"],
+                "x": "Q",
+                "z": "C",
+            },
+            "latestQuote": {
+                "t": "2025-01-14T10:30:01Z",
+                "ap": 150.30,
+                "as": 100,
+                "bp": 150.20,
+                "bs": 200,
+                "ax": "Q",
+                "bx": "N",
+                "c": ["R"],
+                "z": "C",
+            },
+            "minuteBar": {
+                "t": "2025-01-14T10:30:00Z",
+                "o": 150.10,
+                "h":
150.35, + "l": 150.05, + "c": 150.25, + "v": 10000, + "n": 250, + "vw": 150.20, + }, + "dailyBar": { + "t": "2025-01-14T00:00:00Z", + "o": 149.50, + "h": 151.00, + "l": 149.00, + "c": 150.25, + "v": 1000000, + "n": 25000, + "vw": 150.00, + }, + "prevDailyBar": { + "t": "2025-01-13T00:00:00Z", + "o": 148.50, + "h": 150.00, + "l": 148.00, + "c": 149.50, + "v": 1200000, + "n": 28000, + "vw": 149.25, + }, + } + + snapshot = snapshot_class_from_dict(snapshot_data) + + assert isinstance(snapshot, SnapshotModel) + assert snapshot.symbol == "AAPL" + assert snapshot.latest_trade is not None + assert snapshot.latest_trade.price == 150.25 + assert snapshot.latest_quote is not None + assert snapshot.latest_quote.ask == 150.30 + assert snapshot.minute_bar is not None + assert snapshot.minute_bar.close == 150.25 + assert snapshot.daily_bar is not None + assert snapshot.daily_bar.close == 150.25 + assert snapshot.prev_daily_bar is not None + assert snapshot.prev_daily_bar.close == 149.50 + + def test_snapshot_class_from_dict_partial(self): + snapshot_data = { + "symbol": "AAPL", + "latestTrade": { + "t": "2025-01-14T10:30:00Z", + "p": 150.25, + "s": 100, + "c": [], + "x": "Q", + "z": "C", + }, + "latestQuote": None, + "minuteBar": None, + "dailyBar": None, + "prevDailyBar": None, + } + + snapshot = snapshot_class_from_dict(snapshot_data) + + assert isinstance(snapshot, SnapshotModel) + assert snapshot.symbol == "AAPL" + assert snapshot.latest_trade is not None + assert snapshot.latest_quote is None + assert snapshot.minute_bar is None + assert snapshot.daily_bar is None + assert snapshot.prev_daily_bar is None diff --git a/tests/test_stock/test_trades.py b/tests/test_stock/test_trades.py new file mode 100644 index 0000000..64169d0 --- /dev/null +++ b/tests/test_stock/test_trades.py @@ -0,0 +1,330 @@ +import os +from unittest.mock import MagicMock, patch + +import pytest + +from py_alpaca_api import PyAlpacaAPI +from py_alpaca_api.exceptions import APIRequestError, ValidationError 
+from py_alpaca_api.models.trade_model import ( + TradeModel, + TradesResponse, + trade_class_from_dict, +) + + +@pytest.fixture +def alpaca(): + """Create PyAlpacaAPI instance for testing.""" + return PyAlpacaAPI( + api_key=os.environ.get("ALPACA_API_KEY", "test_key"), + api_secret=os.environ.get("ALPACA_SECRET_KEY", "test_secret"), + api_paper=True, + ) + + +@pytest.fixture +def mock_trade_response(): + """Sample trade response data.""" + return { + "t": "2024-01-15T14:30:00Z", + "x": "V", # IEX exchange + "p": 150.25, + "s": 100, + "c": ["@", "I"], + "i": 12345, + "z": "C", + } + + +@pytest.fixture +def mock_trades_response(): + """Sample multiple trades response.""" + return { + "trades": [ + { + "t": "2024-01-15T14:30:00Z", + "x": "V", + "p": 150.25, + "s": 100, + "c": ["@"], + "i": 12345, + "z": "C", + }, + { + "t": "2024-01-15T14:30:01Z", + "x": "K", + "p": 150.26, + "s": 200, + "c": ["F"], + "i": 12346, + "z": "C", + }, + ], + "symbol": "AAPL", + "next_page_token": "token123", + } + + +class TestTrades: + """Test suite for Trades functionality.""" + + def test_trade_model_creation(self, mock_trade_response): + """Test creating a TradeModel from dict.""" + trade = trade_class_from_dict(mock_trade_response, "AAPL") + + assert isinstance(trade, TradeModel) + assert trade.timestamp == "2024-01-15T14:30:00Z" + assert trade.symbol == "AAPL" + assert trade.exchange == "V" + assert trade.price == 150.25 + assert trade.size == 100 + assert trade.conditions == ["@", "I"] + assert trade.id == 12345 + assert trade.tape == "C" + + @patch("py_alpaca_api.http.requests.Requests.request") + def test_get_trades_success(self, mock_request, alpaca, mock_trades_response): + """Test successful retrieval of historical trades.""" + # Setup mock response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = '{"trades": [{"t": "2024-01-15T14:30:00Z", "x": "V", "p": 150.25, "s": 100, "c": ["@"], "i": 12345, "z": "C"}], "symbol": "AAPL", 
"next_page_token": null}' + mock_request.return_value = mock_response + + # Call method + result = alpaca.stock.trades.get_trades( + symbol="AAPL", + start="2024-01-15T14:00:00Z", + end="2024-01-15T15:00:00Z", + limit=100, + ) + + # Verify + assert isinstance(result, TradesResponse) + assert len(result.trades) == 1 + assert result.symbol == "AAPL" + assert result.trades[0].price == 150.25 + + def test_get_trades_validation(self, alpaca): + """Test validation for get_trades parameters.""" + # Test missing symbol + with pytest.raises(ValidationError, match="Symbol is required"): + alpaca.stock.trades.get_trades( + symbol="", + start="2024-01-15T14:00:00Z", + end="2024-01-15T15:00:00Z", + ) + + # Test invalid limit + with pytest.raises(ValidationError, match="Limit must be between"): + alpaca.stock.trades.get_trades( + symbol="AAPL", + start="2024-01-15T14:00:00Z", + end="2024-01-15T15:00:00Z", + limit=10001, + ) + + # Test invalid date format + with pytest.raises(ValidationError, match="Invalid date format"): + alpaca.stock.trades.get_trades( + symbol="AAPL", + start="2024-01-15", # Missing time + end="2024-01-15T15:00:00Z", + ) + + @patch("py_alpaca_api.http.requests.Requests.request") + def test_get_latest_trade_success(self, mock_request, alpaca): + """Test successful retrieval of latest trade.""" + # Setup mock response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = '{"trades": {"AAPL": {"t": "2024-01-15T14:30:00Z", "x": "V", "p": 150.25, "s": 100, "c": ["@"], "i": 12345, "z": "C"}}}' + mock_request.return_value = mock_response + + # Call method + result = alpaca.stock.trades.get_latest_trade("AAPL") + + # Verify + assert isinstance(result, TradeModel) + assert result.symbol == "AAPL" + assert result.price == 150.25 + assert result.exchange == "V" + + @patch("py_alpaca_api.http.requests.Requests.request") + def test_get_latest_trade_not_found(self, mock_request, alpaca): + """Test handling when symbol not found in latest 
trades.""" + # Setup mock response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = '{"trades": {}}' + mock_request.return_value = mock_response + + # Call method and expect error + with pytest.raises(APIRequestError, match="No trade data found"): + alpaca.stock.trades.get_latest_trade("INVALID") + + @patch("py_alpaca_api.http.requests.Requests.request") + def test_get_trades_multi_success(self, mock_request, alpaca): + """Test successful retrieval of trades for multiple symbols.""" + # Setup mock response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = """{ + "trades": { + "AAPL": [{"t": "2024-01-15T14:30:00Z", "x": "V", "p": 150.25, "s": 100, "c": ["@"], "i": 12345, "z": "C"}], + "MSFT": [{"t": "2024-01-15T14:30:00Z", "x": "K", "p": 380.50, "s": 50, "c": ["F"], "i": 12346, "z": "C"}] + }, + "next_page_token": null + }""" + mock_request.return_value = mock_response + + # Call method + result = alpaca.stock.trades.get_trades_multi( + symbols=["AAPL", "MSFT"], + start="2024-01-15T14:00:00Z", + end="2024-01-15T15:00:00Z", + ) + + # Verify + assert isinstance(result, dict) + assert "AAPL" in result + assert "MSFT" in result + assert isinstance(result["AAPL"], TradesResponse) + assert len(result["AAPL"].trades) == 1 + assert result["AAPL"].trades[0].price == 150.25 + assert result["MSFT"].trades[0].price == 380.50 + + def test_get_trades_multi_validation(self, alpaca): + """Test validation for multi-symbol trades.""" + # Test empty symbols list + with pytest.raises(ValidationError, match="At least one symbol"): + alpaca.stock.trades.get_trades_multi( + symbols=[], + start="2024-01-15T14:00:00Z", + end="2024-01-15T15:00:00Z", + ) + + # Test too many symbols + with pytest.raises(ValidationError, match="Maximum 100 symbols"): + alpaca.stock.trades.get_trades_multi( + symbols=["AAPL"] * 101, + start="2024-01-15T14:00:00Z", + end="2024-01-15T15:00:00Z", + ) + + 
@patch("py_alpaca_api.http.requests.Requests.request") + def test_get_latest_trades_multi(self, mock_request, alpaca): + """Test getting latest trades for multiple symbols.""" + # Setup mock response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = """{ + "trades": { + "AAPL": {"t": "2024-01-15T14:30:00Z", "x": "V", "p": 150.25, "s": 100, "c": ["@"], "i": 12345, "z": "C"}, + "MSFT": {"t": "2024-01-15T14:30:01Z", "x": "K", "p": 380.50, "s": 50, "c": ["F"], "i": 12346, "z": "C"} + } + }""" + mock_request.return_value = mock_response + + # Call method + result = alpaca.stock.trades.get_latest_trades_multi(["AAPL", "MSFT"]) + + # Verify + assert isinstance(result, dict) + assert "AAPL" in result + assert "MSFT" in result + assert isinstance(result["AAPL"], TradeModel) + assert result["AAPL"].price == 150.25 + assert result["MSFT"].price == 380.50 + + @patch("py_alpaca_api.http.requests.Requests.request") + def test_get_all_trades_pagination(self, mock_request, alpaca): + """Test get_all_trades with pagination.""" + # Setup mock responses for pagination + responses = [ + MagicMock( + status_code=200, + text='{"trades": [{"t": "2024-01-15T14:30:00Z", "x": "V", "p": 150.25, "s": 100, "c": ["@"], "i": 1, "z": "C"}], "symbol": "AAPL", "next_page_token": "page2"}', + ), + MagicMock( + status_code=200, + text='{"trades": [{"t": "2024-01-15T14:31:00Z", "x": "K", "p": 150.30, "s": 200, "c": ["F"], "i": 2, "z": "C"}], "symbol": "AAPL", "next_page_token": null}', + ), + ] + mock_request.side_effect = responses + + # Call method + result = alpaca.stock.trades.get_all_trades( + symbol="AAPL", + start="2024-01-15T14:00:00Z", + end="2024-01-15T15:00:00Z", + ) + + # Verify + assert isinstance(result, list) + assert len(result) == 2 + assert result[0].price == 150.25 + assert result[1].price == 150.30 + + def test_feed_parameter(self, alpaca): + """Test that feed parameter is properly handled.""" + with 
patch("py_alpaca_api.http.requests.Requests.request") as mock_request: + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = '{"trades": [], "symbol": "AAPL"}' + mock_request.return_value = mock_response + + # Call with feed parameter + alpaca.stock.trades.get_trades( + symbol="AAPL", + start="2024-01-15T14:00:00Z", + end="2024-01-15T15:00:00Z", + feed="sip", + ) + + # Verify feed was included in params + call_args = mock_request.call_args + assert call_args[1]["params"]["feed"] == "sip" + + @patch("py_alpaca_api.http.requests.Requests.request") + def test_api_error_handling(self, mock_request, alpaca): + """Test handling of API errors.""" + # Setup mock error response + mock_response = MagicMock() + mock_response.status_code = 403 + mock_response.text = "Forbidden: Subscription required" + mock_request.return_value = mock_response + + # Call method and expect error + with pytest.raises(APIRequestError) as exc_info: + alpaca.stock.trades.get_trades( + symbol="AAPL", + start="2024-01-15T14:00:00Z", + end="2024-01-15T15:00:00Z", + ) + + assert exc_info.value.status_code == 403 + assert "Failed to retrieve trades" in str(exc_info.value) + + def test_trades_response_model(self): + """Test TradesResponse model creation.""" + trades = [ + TradeModel( + timestamp="2024-01-15T14:30:00Z", + symbol="AAPL", + exchange="V", + price=150.25, + size=100, + conditions=["@"], + id=12345, + tape="C", + ) + ] + + response = TradesResponse( + trades=trades, symbol="AAPL", next_page_token="token123" + ) + + assert response.symbol == "AAPL" + assert len(response.trades) == 1 + assert response.next_page_token == "token123" diff --git a/tests/test_stock/test_trades_live.py b/tests/test_stock/test_trades_live.py new file mode 100644 index 0000000..bc8c206 --- /dev/null +++ b/tests/test_stock/test_trades_live.py @@ -0,0 +1,342 @@ +"""Integration tests for Trades API with live data. 
+ +These tests require valid Alpaca API credentials and will make real API calls. +Run with: ./test.sh +""" + +import os +from datetime import datetime, timedelta + +import pytest + +from py_alpaca_api import PyAlpacaAPI +from py_alpaca_api.exceptions import APIRequestError, ValidationError +from py_alpaca_api.models.trade_model import TradeModel, TradesResponse + +# Skip all tests if no API credentials +pytestmark = pytest.mark.skipif( + not os.environ.get("ALPACA_API_KEY") or not os.environ.get("ALPACA_SECRET_KEY"), + reason="API credentials not available", +) + + +@pytest.fixture +def alpaca(): + """Create PyAlpacaAPI instance with real credentials.""" + return PyAlpacaAPI( + api_key=os.environ.get("ALPACA_API_KEY"), + api_secret=os.environ.get("ALPACA_SECRET_KEY"), + api_paper=True, + ) + + +class TestTradesLive: + """Integration tests for Trades API with live data.""" + + def test_get_trades_historical(self, alpaca): + """Test retrieving historical trades for a symbol.""" + # Use a recent trading day + end_time = datetime.now() + # Go back to previous trading day to ensure market was open + if end_time.weekday() == 0: # Monday + start_time = end_time - timedelta(days=3) # Friday + elif end_time.weekday() == 6: # Sunday + start_time = end_time - timedelta(days=2) # Friday + else: + start_time = end_time - timedelta(days=1) + + # Format times in RFC-3339 + start = start_time.replace(hour=14, minute=0, second=0).isoformat() + "Z" + end = start_time.replace(hour=14, minute=30, second=0).isoformat() + "Z" + + try: + result = alpaca.stock.trades.get_trades( + symbol="AAPL", + start=start, + end=end, + limit=10, + ) + + assert isinstance(result, TradesResponse) + assert result.symbol == "AAPL" + + if result.trades: + # Check first trade + trade = result.trades[0] + assert isinstance(trade, TradeModel) + assert trade.symbol == "AAPL" + assert trade.price > 0 + assert trade.size > 0 + assert trade.exchange + assert trade.timestamp + + print(f"\nFound 
{len(result.trades)} trades for AAPL") + print( + f"First trade: ${trade.price} x {trade.size} at {trade.timestamp}" + ) + + except APIRequestError as e: + if e.status_code in [403, 429]: + pytest.skip(f"API rate limit or subscription issue: {e}") + raise + + def test_get_latest_trade(self, alpaca): + """Test retrieving the latest trade for a symbol.""" + try: + trade = alpaca.stock.trades.get_latest_trade("AAPL") + + assert isinstance(trade, TradeModel) + assert trade.symbol == "AAPL" + assert trade.price > 0 + assert trade.size > 0 + assert trade.exchange + assert trade.timestamp + + print(f"\nLatest AAPL trade: ${trade.price} x {trade.size}") + print(f" Exchange: {trade.exchange}") + print(f" Time: {trade.timestamp}") + + except APIRequestError as e: + if e.status_code in [403, 429]: + pytest.skip(f"API rate limit or subscription issue: {e}") + raise + + def test_get_trades_with_feed(self, alpaca): + """Test retrieving trades with different feed types.""" + end_time = datetime.now() + start_time = end_time - timedelta(hours=1) + + start = start_time.isoformat() + "Z" + end = end_time.isoformat() + "Z" + + # Try IEX feed (usually available for free tier) + try: + result = alpaca.stock.trades.get_trades( + symbol="AAPL", + start=start, + end=end, + limit=5, + feed="iex", + ) + + assert isinstance(result, TradesResponse) + print(f"\nIEX feed returned {len(result.trades)} trades") + + except APIRequestError as e: + if e.status_code == 403: + print("\nIEX feed not available (subscription required)") + elif e.status_code == 429: + pytest.skip("Rate limit reached") + else: + raise + + def test_get_trades_multi(self, alpaca): + """Test retrieving trades for multiple symbols.""" + end_time = datetime.now() + start_time = end_time - timedelta(hours=1) + + start = start_time.isoformat() + "Z" + end = end_time.isoformat() + "Z" + + try: + result = alpaca.stock.trades.get_trades_multi( + symbols=["AAPL", "MSFT", "GOOGL"], + start=start, + end=end, + limit=5, + ) + + assert 
isinstance(result, dict) + + for symbol in ["AAPL", "MSFT", "GOOGL"]: + if symbol in result: + assert isinstance(result[symbol], TradesResponse) + assert result[symbol].symbol == symbol + + if result[symbol].trades: + print(f"\n{symbol}: {len(result[symbol].trades)} trades") + first_trade = result[symbol].trades[0] + print(f" First trade: ${first_trade.price}") + + except APIRequestError as e: + if e.status_code in [403, 429]: + pytest.skip(f"API rate limit or subscription issue: {e}") + raise + + def test_get_latest_trades_multi(self, alpaca): + """Test getting latest trades for multiple symbols.""" + symbols = ["AAPL", "MSFT", "GOOGL", "AMZN", "TSLA"] + + try: + result = alpaca.stock.trades.get_latest_trades_multi(symbols) + + assert isinstance(result, dict) + + print("\nLatest trades for multiple symbols:") + for symbol in symbols: + if symbol in result: + trade = result[symbol] + assert isinstance(trade, TradeModel) + assert trade.symbol == symbol + assert trade.price > 0 + + print(f" {symbol}: ${trade.price:,.2f} x {trade.size}") + + except APIRequestError as e: + if e.status_code in [403, 429]: + pytest.skip(f"API rate limit or subscription issue: {e}") + raise + + def test_pagination(self, alpaca): + """Test pagination with large result sets.""" + # Use a wider time range to get more trades + end_time = datetime.now() + start_time = end_time - timedelta(days=1) + + start = start_time.isoformat() + "Z" + end = end_time.isoformat() + "Z" + + try: + # First request with small limit + first_page = alpaca.stock.trades.get_trades( + symbol="SPY", # SPY typically has high volume + start=start, + end=end, + limit=100, + ) + + assert isinstance(first_page, TradesResponse) + + if first_page.next_page_token: + # Get next page + second_page = alpaca.stock.trades.get_trades( + symbol="SPY", + start=start, + end=end, + limit=100, + page_token=first_page.next_page_token, + ) + + assert isinstance(second_page, TradesResponse) + print("\nPagination test:") + print(f" First 
page: {len(first_page.trades)} trades") + print(f" Second page: {len(second_page.trades)} trades") + print(f" Has more pages: {second_page.next_page_token is not None}") + else: + print(f"\nOnly one page of results ({len(first_page.trades)} trades)") + + except APIRequestError as e: + if e.status_code in [403, 429]: + pytest.skip(f"API rate limit or subscription issue: {e}") + raise + + def test_get_all_trades(self, alpaca): + """Test getting all trades with automatic pagination.""" + # Use a short time window to avoid too much data + end_time = datetime.now() + start_time = end_time - timedelta(minutes=5) + + start = start_time.isoformat() + "Z" + end = end_time.isoformat() + "Z" + + try: + all_trades = alpaca.stock.trades.get_all_trades( + symbol="AAPL", + start=start, + end=end, + ) + + assert isinstance(all_trades, list) + + if all_trades: + assert all(isinstance(trade, TradeModel) for trade in all_trades) + assert all(trade.symbol == "AAPL" for trade in all_trades) + + print(f"\nRetrieved {len(all_trades)} total trades across all pages") + + # Check trades are in chronological order + # Parse timestamps to handle different precision levels + parsed_timestamps = [ + datetime.fromisoformat(trade.timestamp.replace("Z", "+00:00")) + for trade in all_trades + ] + assert parsed_timestamps == sorted( + parsed_timestamps + ), "Trades are not in chronological order" + print(" Trades are in chronological order βœ“") + + except APIRequestError as e: + if e.status_code in [403, 429]: + pytest.skip(f"API rate limit or subscription issue: {e}") + raise + + def test_error_handling(self, alpaca): + """Test error handling for invalid requests.""" + # Test with invalid symbol + with pytest.raises(APIRequestError): + alpaca.stock.trades.get_latest_trade("INVALID_SYMBOL_XYZ") + + # Test with invalid date range + with pytest.raises(ValidationError): + alpaca.stock.trades.get_trades( + symbol="AAPL", + start="invalid_date", + end="2024-01-15T15:00:00Z", + ) + + # Test with too many 
symbols + with pytest.raises(ValidationError): + alpaca.stock.trades.get_trades_multi( + symbols=["AAPL"] * 101, # Over 100 symbol limit + start="2024-01-15T14:00:00Z", + end="2024-01-15T15:00:00Z", + ) + + def test_trade_conditions(self, alpaca): + """Test that trade conditions are properly captured.""" + end_time = datetime.now() + start_time = end_time - timedelta(hours=1) + + start = start_time.isoformat() + "Z" + end = end_time.isoformat() + "Z" + + try: + result = alpaca.stock.trades.get_trades( + symbol="SPY", + start=start, + end=end, + limit=100, + ) + + if result.trades: + # Check for trades with conditions + trades_with_conditions = [t for t in result.trades if t.conditions] + + if trades_with_conditions: + print( + f"\nFound {len(trades_with_conditions)} trades with conditions" + ) + # Print some example conditions + unique_conditions = set() + for trade in trades_with_conditions: + if trade.conditions: + unique_conditions.update(trade.conditions) + + print(f" Unique conditions seen: {sorted(unique_conditions)}") + + except APIRequestError as e: + if e.status_code in [403, 429]: + pytest.skip(f"API rate limit or subscription issue: {e}") + raise + + def test_asof_parameter(self, alpaca): + """Test the as-of parameter for historical point-in-time data.""" + # Skip this test as asof requires historical dates and proper subscription + pytest.skip( + "As-of parameter requires specific historical dates and subscription" + ) + + +if __name__ == "__main__": + # Allow running this file directly for testing + pytest.main([__file__, "-v"]) diff --git a/tests/test_trading/test_account_config.py b/tests/test_trading/test_account_config.py new file mode 100644 index 0000000..ef0456c --- /dev/null +++ b/tests/test_trading/test_account_config.py @@ -0,0 +1,252 @@ +import json +from unittest.mock import MagicMock, patch + +import pytest + +from py_alpaca_api.exceptions import APIRequestError +from py_alpaca_api.models.account_config_model import ( + AccountConfigModel, 
+ account_config_class_from_dict, +) +from py_alpaca_api.trading.account import Account + + +class TestAccountConfig: + @pytest.fixture + def account(self): + headers = { + "APCA-API-KEY-ID": "test_key", + "APCA-API-SECRET-KEY": "test_secret", + } + base_url = "https://paper-api.alpaca.markets/v2" + return Account(headers=headers, base_url=base_url) + + @pytest.fixture + def mock_config_response(self): + return { + "dtbp_check": "entry", + "fractional_trading": True, + "max_margin_multiplier": "4", + "no_shorting": False, + "pdt_check": "entry", + "ptp_no_exception_entry": False, + "suspend_trade": False, + "trade_confirm_email": "all", + } + + def test_get_configuration_success(self, account, mock_config_response): + with patch("py_alpaca_api.trading.account.Requests") as mock_requests: + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = json.dumps(mock_config_response) + mock_requests.return_value.request.return_value = mock_response + + result = account.get_configuration() + + assert isinstance(result, AccountConfigModel) + assert result.dtbp_check == "entry" + assert result.fractional_trading is True + assert result.max_margin_multiplier == "4" + assert result.no_shorting is False + assert result.pdt_check == "entry" + assert result.ptp_no_exception_entry is False + assert result.suspend_trade is False + assert result.trade_confirm_email == "all" + + mock_requests.return_value.request.assert_called_once_with( + "GET", + f"{account.base_url}/account/configurations", + headers=account.headers, + ) + + def test_get_configuration_failure(self, account): + with patch("py_alpaca_api.trading.account.Requests") as mock_requests: + mock_response = MagicMock() + mock_response.status_code = 401 + mock_response.text = "Unauthorized" + mock_requests.return_value.request.return_value = mock_response + + with pytest.raises(APIRequestError) as exc_info: + account.get_configuration() + + assert exc_info.value.status_code == 401 + assert 
"Failed to retrieve account configuration" in str(exc_info.value) + + def test_update_configuration_single_param(self, account, mock_config_response): + mock_config_response["suspend_trade"] = True + + with patch("py_alpaca_api.trading.account.Requests") as mock_requests: + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = json.dumps(mock_config_response) + mock_requests.return_value.request.return_value = mock_response + + result = account.update_configuration(suspend_trade=True) + + assert isinstance(result, AccountConfigModel) + assert result.suspend_trade is True + + mock_requests.return_value.request.assert_called_once_with( + "PATCH", + f"{account.base_url}/account/configurations", + headers=account.headers, + json={"suspend_trade": True}, + ) + + def test_update_configuration_multiple_params(self, account, mock_config_response): + mock_config_response["no_shorting"] = True + mock_config_response["trade_confirm_email"] = "none" + + with patch("py_alpaca_api.trading.account.Requests") as mock_requests: + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = json.dumps(mock_config_response) + mock_requests.return_value.request.return_value = mock_response + + result = account.update_configuration( + no_shorting=True, trade_confirm_email="none" + ) + + assert isinstance(result, AccountConfigModel) + assert result.no_shorting is True + assert result.trade_confirm_email == "none" + + mock_requests.return_value.request.assert_called_once_with( + "PATCH", + f"{account.base_url}/account/configurations", + headers=account.headers, + json={"no_shorting": True, "trade_confirm_email": "none"}, + ) + + def test_update_configuration_all_params(self, account): + updated_config = { + "dtbp_check": "both", + "fractional_trading": False, + "max_margin_multiplier": "2", + "no_shorting": True, + "pdt_check": "exit", + "ptp_no_exception_entry": True, + "suspend_trade": True, + "trade_confirm_email": "none", + } + + 
with patch("py_alpaca_api.trading.account.Requests") as mock_requests: + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = json.dumps(updated_config) + mock_requests.return_value.request.return_value = mock_response + + result = account.update_configuration( + dtbp_check="both", + fractional_trading=False, + max_margin_multiplier="2", + no_shorting=True, + pdt_check="exit", + ptp_no_exception_entry=True, + suspend_trade=True, + trade_confirm_email="none", + ) + + assert isinstance(result, AccountConfigModel) + assert result.dtbp_check == "both" + assert result.fractional_trading is False + assert result.max_margin_multiplier == "2" + assert result.no_shorting is True + assert result.pdt_check == "exit" + assert result.ptp_no_exception_entry is True + assert result.suspend_trade is True + assert result.trade_confirm_email == "none" + + def test_update_configuration_invalid_dtbp_check(self, account): + with pytest.raises(ValueError, match="dtbp_check must be one of"): + account.update_configuration(dtbp_check="invalid") + + def test_update_configuration_invalid_pdt_check(self, account): + with pytest.raises(ValueError, match="pdt_check must be one of"): + account.update_configuration(pdt_check="invalid") + + def test_update_configuration_invalid_margin_multiplier(self, account): + with pytest.raises(ValueError, match="max_margin_multiplier must be one of"): + account.update_configuration(max_margin_multiplier="3") + + def test_update_configuration_invalid_trade_confirm_email(self, account): + with pytest.raises(ValueError, match="trade_confirm_email must be one of"): + account.update_configuration(trade_confirm_email="some") + + def test_update_configuration_no_params(self, account): + with pytest.raises( + ValueError, match="At least one configuration parameter must be provided" + ): + account.update_configuration() + + def test_update_configuration_failure(self, account): + with patch("py_alpaca_api.trading.account.Requests") as 
mock_requests: + mock_response = MagicMock() + mock_response.status_code = 400 + mock_response.text = "Bad Request" + mock_requests.return_value.request.return_value = mock_response + + with pytest.raises(APIRequestError) as exc_info: + account.update_configuration(suspend_trade=True) + + assert exc_info.value.status_code == 400 + assert "Failed to update account configuration" in str(exc_info.value) + + +class TestAccountConfigModel: + def test_account_config_from_dict(self): + data = { + "dtbp_check": "both", + "fractional_trading": True, + "max_margin_multiplier": "2", + "no_shorting": True, + "pdt_check": "exit", + "ptp_no_exception_entry": True, + "suspend_trade": False, + "trade_confirm_email": "none", + } + + config = account_config_class_from_dict(data) + + assert isinstance(config, AccountConfigModel) + assert config.dtbp_check == "both" + assert config.fractional_trading is True + assert config.max_margin_multiplier == "2" + assert config.no_shorting is True + assert config.pdt_check == "exit" + assert config.ptp_no_exception_entry is True + assert config.suspend_trade is False + assert config.trade_confirm_email == "none" + + def test_account_config_from_dict_with_defaults(self): + # Test with empty dict to verify defaults + data = {} + + config = account_config_class_from_dict(data) + + assert isinstance(config, AccountConfigModel) + assert config.dtbp_check == "entry" + assert config.fractional_trading is False + assert config.max_margin_multiplier == "1" + assert config.no_shorting is False + assert config.pdt_check == "entry" + assert config.ptp_no_exception_entry is False + assert config.suspend_trade is False + assert config.trade_confirm_email == "all" + + def test_account_config_from_dict_partial(self): + data = { + "dtbp_check": "exit", + "fractional_trading": True, + "trade_confirm_email": "none", + } + + config = account_config_class_from_dict(data) + + assert config.dtbp_check == "exit" + assert config.fractional_trading is True + assert 
config.trade_confirm_email == "none" + # Check defaults for missing fields + assert config.max_margin_multiplier == "1" + assert config.no_shorting is False + assert config.pdt_check == "entry" diff --git a/tests/test_trading/test_corporate_actions.py b/tests/test_trading/test_corporate_actions.py new file mode 100644 index 0000000..46c770e --- /dev/null +++ b/tests/test_trading/test_corporate_actions.py @@ -0,0 +1,344 @@ +import os +from datetime import datetime, timedelta +from unittest.mock import MagicMock, patch + +import pytest + +from py_alpaca_api import PyAlpacaAPI +from py_alpaca_api.exceptions import APIRequestError, ValidationError +from py_alpaca_api.models.corporate_action_model import ( + CorporateActionModel, + DividendModel, + MergerModel, + SpinoffModel, + SplitModel, + corporate_action_class_from_dict, +) + + +@pytest.fixture +def alpaca(): + """Create PyAlpacaAPI instance for testing.""" + return PyAlpacaAPI( + api_key=os.environ.get("ALPACA_API_KEY", "test_key"), + api_secret=os.environ.get("ALPACA_SECRET_KEY", "test_secret"), + api_paper=True, + ) + + +@pytest.fixture +def mock_dividend_response(): + """Sample dividend corporate action response.""" + return { + "id": "123456", + "corporate_action_id": "CA123", + "ca_type": "dividend", + "ca_sub_type": "cash", + "initiating_symbol": "AAPL", + "initiating_original_cusip": "037833100", + "declaration_date": "2024-01-15", + "ex_date": "2024-02-01", + "record_date": "2024-02-02", + "payable_date": "2024-02-15", + "cash": 0.96, + "cash_amount": 0.96, + "dividend_type": "quarterly", + "frequency": 4, + } + + +@pytest.fixture +def mock_split_response(): + """Sample stock split corporate action response.""" + return { + "id": "789012", + "corporate_action_id": "CA789", + "ca_type": "split", + "ca_sub_type": "stock_split", + "initiating_symbol": "TSLA", + "initiating_original_cusip": "88160R101", + "declaration_date": "2024-01-10", + "ex_date": "2024-01-25", + "split_from": 1.0, + "split_to": 3.0, + } + 
+ +@pytest.fixture +def mock_merger_response(): + """Sample merger corporate action response.""" + return { + "id": "345678", + "corporate_action_id": "CA345", + "ca_type": "merger", + "ca_sub_type": "acquisition", + "initiating_symbol": "TARGET", + "target_symbol": "TARGET", + "acquirer_symbol": "BUYER", + "acquirer_cusip": "123456789", + "declaration_date": "2024-01-05", + "ex_date": "2024-03-01", + "cash_rate": 0.5, + "stock_rate": 1.2, + } + + +@pytest.fixture +def mock_spinoff_response(): + """Sample spinoff corporate action response.""" + return { + "id": "901234", + "corporate_action_id": "CA901", + "ca_type": "spinoff", + "ca_sub_type": "spinoff", + "initiating_symbol": "PARENT", + "new_symbol": "CHILD", + "new_cusip": "987654321", + "declaration_date": "2024-01-20", + "ex_date": "2024-02-15", + "ratio": 0.25, + } + + +class TestCorporateActions: + """Test suite for Corporate Actions functionality.""" + + def test_get_announcements_success(self, alpaca): + """Test successful retrieval of corporate action announcements.""" + with patch.object( + alpaca.trading.corporate_actions, "get_announcements" + ) as mock_get: + # Setup mock response + mock_get.return_value = [ + DividendModel( + id="123", + corporate_action_id="CA123", + ca_type="dividend", + ca_sub_type="cash", + initiating_symbol="AAPL", + initiating_original_cusip="037833100", + target_symbol=None, + target_original_cusip=None, + declaration_date="2024-01-15", + ex_date="2024-02-01", + record_date="2024-02-02", + payable_date="2024-02-15", + cash=0.96, + old_rate=None, + new_rate=None, + cash_amount=0.96, + dividend_type="quarterly", + frequency=4, + ) + ] + + # Call method + result = alpaca.trading.corporate_actions.get_announcements( + since="2024-01-01", + until="2024-03-31", + ca_types=["dividend"], + symbol="AAPL", + ) + + # Verify + assert len(result) == 1 + assert isinstance(result[0], DividendModel) + assert result[0].initiating_symbol == "AAPL" + assert result[0].cash_amount == 0.96 + + def 
test_get_announcements_date_validation(self, alpaca): + """Test date range validation for get_announcements.""" + # Test date range exceeds 90 days + with pytest.raises(ValidationError, match="Date range cannot exceed 90 days"): + alpaca.trading.corporate_actions.get_announcements( + since="2024-01-01", + until="2024-06-01", # More than 90 days + ca_types=["dividend"], + ) + + # Test invalid date format + with pytest.raises(ValidationError, match="Invalid date format"): + alpaca.trading.corporate_actions.get_announcements( + since="01-01-2024", # Wrong format + until="2024-03-31", + ca_types=["dividend"], + ) + + # Test since after until + with pytest.raises( + ValidationError, match="'since' date must be before 'until'" + ): + alpaca.trading.corporate_actions.get_announcements( + since="2024-03-31", + until="2024-01-01", + ca_types=["dividend"], + ) + + def test_get_announcements_type_validation(self, alpaca): + """Test corporate action type validation.""" + with pytest.raises(ValidationError, match="Invalid corporate action type"): + alpaca.trading.corporate_actions.get_announcements( + since="2024-01-01", + until="2024-03-31", + ca_types=["invalid_type"], + ) + + @patch("py_alpaca_api.http.requests.Requests.request") + def test_get_announcement_by_id_success(self, mock_request, alpaca): + """Test successful retrieval of a specific announcement.""" + # Setup mock response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = ( + '{"id": "123456", "ca_type": "dividend", "corporate_action_id": "CA123"}' + ) + mock_request.return_value = mock_response + + # Call method + result = alpaca.trading.corporate_actions.get_announcement_by_id("123456") + + # Verify + assert isinstance(result, CorporateActionModel | DividendModel) + assert result.id == "123456" + + @patch("py_alpaca_api.http.requests.Requests.request") + def test_get_announcement_by_id_not_found(self, mock_request, alpaca): + """Test handling of not found announcement.""" + # 
Setup mock response + mock_response = MagicMock() + mock_response.status_code = 404 + mock_response.text = "Not found" + mock_request.return_value = mock_response + + # Call method and expect error + with pytest.raises(APIRequestError, match="announcement not found"): + alpaca.trading.corporate_actions.get_announcement_by_id("invalid_id") + + def test_corporate_action_model_creation_dividend(self, mock_dividend_response): + """Test creation of DividendModel from dict.""" + model = corporate_action_class_from_dict(mock_dividend_response) + + assert isinstance(model, DividendModel) + assert model.id == "123456" + assert model.ca_type == "dividend" + assert model.cash_amount == 0.96 + assert model.dividend_type == "quarterly" + assert model.frequency == 4 + + def test_corporate_action_model_creation_split(self, mock_split_response): + """Test creation of SplitModel from dict.""" + model = corporate_action_class_from_dict(mock_split_response) + + assert isinstance(model, SplitModel) + assert model.id == "789012" + assert model.ca_type == "split" + assert model.split_from == 1.0 + assert model.split_to == 3.0 + + def test_corporate_action_model_creation_merger(self, mock_merger_response): + """Test creation of MergerModel from dict.""" + model = corporate_action_class_from_dict(mock_merger_response) + + assert isinstance(model, MergerModel) + assert model.id == "345678" + assert model.ca_type == "merger" + assert model.acquirer_symbol == "BUYER" + assert model.cash_rate == 0.5 + assert model.stock_rate == 1.2 + + def test_corporate_action_model_creation_spinoff(self, mock_spinoff_response): + """Test creation of SpinoffModel from dict.""" + model = corporate_action_class_from_dict(mock_spinoff_response) + + assert isinstance(model, SpinoffModel) + assert model.id == "901234" + assert model.ca_type == "spinoff" + assert model.new_symbol == "CHILD" + assert model.ratio == 0.25 + + def test_corporate_action_model_unknown_type(self): + """Test handling of unknown corporate 
action type.""" + unknown_data = { + "id": "999", + "corporate_action_id": "CA999", + "ca_type": "unknown_type", + "initiating_symbol": "TEST", + } + + model = corporate_action_class_from_dict(unknown_data) + + assert isinstance(model, CorporateActionModel) + assert not isinstance( + model, DividendModel | SplitModel | MergerModel | SpinoffModel + ) + assert model.id == "999" + assert model.ca_type == "unknown_type" + + @patch("py_alpaca_api.http.requests.Requests.request") + def test_get_all_announcements_pagination(self, mock_request, alpaca): + """Test get_all_announcements with pagination handling.""" + # Setup mock response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = '{"announcements": [{"id": "1", "ca_type": "dividend", "corporate_action_id": "CA1"}]}' + mock_request.return_value = mock_response + + # Call method + result = alpaca.trading.corporate_actions.get_all_announcements( + since="2024-01-01", + until="2024-03-31", + ca_types=["dividend", "split"], + ) + + # Verify + assert isinstance(result, list) + # Note: Current implementation doesn't handle pagination fully, + # this test ensures the method works + + @pytest.mark.skipif( + not os.environ.get("ALPACA_API_KEY"), reason="API credentials not available" + ) + def test_live_api_call(self, alpaca): + """Test actual API call (requires valid credentials).""" + # Use a recent date range + today = datetime.now() + since = (today - timedelta(days=30)).strftime("%Y-%m-%d") + until = today.strftime("%Y-%m-%d") + + try: + result = alpaca.trading.corporate_actions.get_announcements( + since=since, + until=until, + ca_types=["dividend"], + ) + # Just verify it returns a list (may be empty) + assert isinstance(result, list) + except APIRequestError as e: + # API might return error for various reasons (auth, rate limit, endpoint not available, etc.) 
+ assert e.status_code in [401, 403, 404, 429] + + def test_optional_parameters(self, alpaca): + """Test that optional parameters are handled correctly.""" + with patch.object( + alpaca.trading.corporate_actions, "get_announcements" + ) as mock_get: + mock_get.return_value = [] + + # Call with optional parameters + alpaca.trading.corporate_actions.get_announcements( + since="2024-01-01", + until="2024-03-31", + ca_types=["dividend"], + symbol="AAPL", + cusip="037833100", + date_type="ex_date", + page_limit=50, + page_token="next_page", + ) + + # Verify the method was called with all parameters + mock_get.assert_called_once() + call_args = mock_get.call_args[1] + assert call_args["symbol"] == "AAPL" + assert call_args["cusip"] == "037833100" + assert call_args["date_type"] == "ex_date" diff --git a/tests/test_trading/test_corporate_actions_live.py b/tests/test_trading/test_corporate_actions_live.py new file mode 100644 index 0000000..b9bedba --- /dev/null +++ b/tests/test_trading/test_corporate_actions_live.py @@ -0,0 +1,395 @@ +"""Integration tests for Corporate Actions API with live data. + +These tests require valid Alpaca API credentials and will make real API calls. 
+Run with: ./test.sh +""" + +import os +from datetime import datetime, timedelta + +import pytest + +from py_alpaca_api import PyAlpacaAPI +from py_alpaca_api.exceptions import APIRequestError, ValidationError +from py_alpaca_api.models.corporate_action_model import ( + CorporateActionModel, + DividendModel, + MergerModel, + SpinoffModel, + SplitModel, +) + +# Skip all tests if no API credentials +pytestmark = pytest.mark.skipif( + not os.environ.get("ALPACA_API_KEY") or not os.environ.get("ALPACA_SECRET_KEY"), + reason="API credentials not available", +) + + +@pytest.fixture +def alpaca(): + """Create PyAlpacaAPI instance with real credentials.""" + return PyAlpacaAPI( + api_key=os.environ.get("ALPACA_API_KEY"), + api_secret=os.environ.get("ALPACA_SECRET_KEY"), + api_paper=True, + ) + + +class TestCorporateActionsLive: + """Integration tests for Corporate Actions API with live data.""" + + def test_get_recent_dividends(self, alpaca): + """Test retrieving recent dividend announcements.""" + # Use a recent 30-day window + today = datetime.now() + since = (today - timedelta(days=30)).strftime("%Y-%m-%d") + until = today.strftime("%Y-%m-%d") + + try: + dividends = alpaca.trading.corporate_actions.get_announcements( + since=since, + until=until, + ca_types=["dividend"], + ) + + # Check the response structure + assert isinstance(dividends, list) + + # If we have dividends, verify their structure + if dividends: + dividend = dividends[0] + assert isinstance(dividend, CorporateActionModel | DividendModel) + assert hasattr(dividend, "id") + assert hasattr(dividend, "corporate_action_id") + assert hasattr(dividend, "ca_type") + assert dividend.ca_type == "dividend" + + if isinstance(dividend, DividendModel): + # Check dividend-specific fields + assert hasattr(dividend, "cash_amount") + assert hasattr(dividend, "dividend_type") + + print(f"Found {len(dividends)} dividend announcements") + for div in dividends[:5]: # Print first 5 + print( + f" {div.initiating_symbol}: 
${getattr(div, 'cash_amount', div.cash)} on {div.payable_date}" + ) + + except APIRequestError as e: + # If endpoint not available (404) or auth issues, skip + if e.status_code in [404, 401, 403]: + pytest.skip(f"API endpoint not available or auth issue: {e}") + raise + + def test_get_recent_splits(self, alpaca): + """Test retrieving recent stock split announcements.""" + # Use a 60-day window for better chance of finding splits + today = datetime.now() + since = (today - timedelta(days=60)).strftime("%Y-%m-%d") + until = today.strftime("%Y-%m-%d") + + try: + splits = alpaca.trading.corporate_actions.get_announcements( + since=since, + until=until, + ca_types=["split"], + ) + + assert isinstance(splits, list) + + if splits: + split = splits[0] + assert isinstance(split, CorporateActionModel | SplitModel) + assert split.ca_type == "split" + + if isinstance(split, SplitModel): + assert hasattr(split, "split_from") + assert hasattr(split, "split_to") + + print(f"Found {len(splits)} split announcements") + for s in splits[:5]: + if hasattr(s, "split_from") and hasattr(s, "split_to"): + print(f" {s.initiating_symbol}: {s.split_from}:{s.split_to}") + + except APIRequestError as e: + if e.status_code in [404, 401, 403]: + pytest.skip(f"API endpoint not available or auth issue: {e}") + raise + + def test_get_specific_symbol_actions(self, alpaca): + """Test retrieving corporate actions for specific symbols.""" + # Use popular stocks likely to have dividends + test_symbols = ["AAPL", "MSFT", "JNJ", "KO"] + + today = datetime.now() + since = (today - timedelta(days=90)).strftime("%Y-%m-%d") + until = today.strftime("%Y-%m-%d") + + for symbol in test_symbols: + try: + actions = alpaca.trading.corporate_actions.get_announcements( + since=since, + until=until, + ca_types=["dividend", "split"], + symbol=symbol, + ) + + assert isinstance(actions, list) + + if actions: + print(f"\n{symbol} corporate actions ({len(actions)} found):") + for action in actions: + # Check that action is 
related to the symbol + symbols = {action.initiating_symbol, action.target_symbol} + assert symbol in symbols + print(f" Type: {action.ca_type}, Date: {action.ex_date}") + break # Found data, test successful + + except APIRequestError as e: + if e.status_code in [404, 401, 403]: + pytest.skip(f"API endpoint not available: {e}") + # Continue to next symbol if current one fails + continue + + def test_get_all_action_types(self, alpaca): + """Test retrieving all types of corporate actions.""" + today = datetime.now() + since = (today - timedelta(days=30)).strftime("%Y-%m-%d") + until = today.strftime("%Y-%m-%d") + + try: + all_actions = alpaca.trading.corporate_actions.get_announcements( + since=since, + until=until, + ca_types=["dividend", "split", "merger", "spinoff"], + ) + + assert isinstance(all_actions, list) + + # Count different types + action_counts = {} + for action in all_actions: + ca_type = action.ca_type + action_counts[ca_type] = action_counts.get(ca_type, 0) + 1 + + print("\nCorporate actions summary (last 30 days):") + print(f" Total: {len(all_actions)}") + for ca_type, count in action_counts.items(): + print(f" {ca_type.capitalize()}: {count}") + + except APIRequestError as e: + if e.status_code in [404, 401, 403]: + pytest.skip(f"API endpoint not available: {e}") + raise + + def test_date_filtering(self, alpaca): + """Test different date filtering options.""" + # Test with ex_dividend date filtering + today = datetime.now() + since = (today - timedelta(days=30)).strftime("%Y-%m-%d") + until = today.strftime("%Y-%m-%d") + + try: + # Filter by ex-dividend date + ex_date_actions = alpaca.trading.corporate_actions.get_announcements( + since=since, + until=until, + ca_types=["dividend"], + date_type="ex_date", + ) + + # Filter by payable date + payable_date_actions = alpaca.trading.corporate_actions.get_announcements( + since=since, + until=until, + ca_types=["dividend"], + date_type="payable_date", + ) + + assert isinstance(ex_date_actions, list) + 
assert isinstance(payable_date_actions, list) + + print("\nDate filtering results:") + print(f" Ex-dividend date filter: {len(ex_date_actions)} results") + print(f" Payable date filter: {len(payable_date_actions)} results") + + except APIRequestError as e: + if e.status_code in [404, 401, 403]: + pytest.skip(f"API endpoint not available: {e}") + raise + + def test_pagination_handling(self, alpaca): + """Test that API returns all results within date range.""" + # Note: The API currently returns all results regardless of page_limit + # This test documents the actual behavior + + # Use a shorter date range to limit results + today = datetime.now() + since = (today - timedelta(days=7)).strftime("%Y-%m-%d") + until = today.strftime("%Y-%m-%d") + + try: + # Request with page_limit parameter (API ignores it but we include it) + results_with_limit = alpaca.trading.corporate_actions.get_announcements( + since=since, + until=until, + ca_types=["dividend"], + page_limit=10, + ) + + # Request without page_limit + results_without_limit = alpaca.trading.corporate_actions.get_announcements( + since=since, + until=until, + ca_types=["dividend"], + ) + + assert isinstance(results_with_limit, list) + assert isinstance(results_without_limit, list) + + # API returns all results regardless of page_limit + assert len(results_with_limit) == len(results_without_limit) + + print("\nPagination behavior test:") + print(f" Results with page_limit=10: {len(results_with_limit)}") + print(f" Results without page_limit: {len(results_without_limit)}") + print(" Note: API currently returns all results within date range") + + except APIRequestError as e: + if e.status_code in [404, 401, 403]: + pytest.skip(f"API endpoint not available: {e}") + raise + + def test_get_announcement_by_id(self, alpaca): + """Test retrieving a specific announcement by ID.""" + # First, get some announcements to have valid IDs + today = datetime.now() + since = (today - timedelta(days=30)).strftime("%Y-%m-%d") + until = 
today.strftime("%Y-%m-%d") + + try: + announcements = alpaca.trading.corporate_actions.get_announcements( + since=since, + until=until, + ca_types=["dividend", "split"], + page_limit=5, + ) + + if not announcements: + pytest.skip("No announcements found to test get_by_id") + + # Get the first announcement's ID + announcement_id = announcements[0].id + + # Retrieve it by ID + single_announcement = ( + alpaca.trading.corporate_actions.get_announcement_by_id(announcement_id) + ) + + assert isinstance(single_announcement, CorporateActionModel) + assert single_announcement.id == announcement_id + assert single_announcement.ca_type in [ + "dividend", + "split", + "merger", + "spinoff", + ] + + print("\nRetrieved announcement by ID:") + print(f" ID: {single_announcement.id}") + print(f" Type: {single_announcement.ca_type}") + print(f" Symbol: {single_announcement.initiating_symbol}") + + except APIRequestError as e: + if e.status_code in [404, 401, 403]: + pytest.skip(f"API endpoint not available: {e}") + raise + + def test_error_handling(self, alpaca): + """Test error handling for invalid requests.""" + # Test with invalid announcement ID + with pytest.raises(APIRequestError) as exc_info: + alpaca.trading.corporate_actions.get_announcement_by_id("invalid_id_12345") + + # Should get 404 for non-existent ID + assert exc_info.value.status_code == 404 + + # Test with date range exceeding 90 days + with pytest.raises(ValidationError, match="Date range cannot exceed 90 days"): + alpaca.trading.corporate_actions.get_announcements( + since="2024-01-01", + until="2024-06-01", + ca_types=["dividend"], + ) + + # Test with invalid date format + with pytest.raises(ValidationError, match="Invalid date format"): + alpaca.trading.corporate_actions.get_announcements( + since="01/01/2024", + until="01/31/2024", + ca_types=["dividend"], + ) + + # Test with invalid corporate action type + with pytest.raises(ValidationError, match="Invalid corporate action type"): + 
alpaca.trading.corporate_actions.get_announcements( + since="2024-01-01", + until="2024-01-31", + ca_types=["invalid_type"], + ) + + def test_mergers_and_spinoffs(self, alpaca): + """Test retrieving merger and spinoff announcements.""" + # Use wider date range as these are less common + today = datetime.now() + since = (today - timedelta(days=90)).strftime("%Y-%m-%d") + until = today.strftime("%Y-%m-%d") + + try: + # Get mergers + mergers = alpaca.trading.corporate_actions.get_announcements( + since=since, + until=until, + ca_types=["merger"], + ) + + # Get spinoffs + spinoffs = alpaca.trading.corporate_actions.get_announcements( + since=since, + until=until, + ca_types=["spinoff"], + ) + + assert isinstance(mergers, list) + assert isinstance(spinoffs, list) + + print("\nMergers and Spinoffs (last 90 days):") + print(f" Mergers: {len(mergers)}") + print(f" Spinoffs: {len(spinoffs)}") + + if mergers: + merger = mergers[0] + if isinstance(merger, MergerModel): + assert hasattr(merger, "acquirer_symbol") + print( + f" Example merger: {merger.target_symbol} acquired by {merger.acquirer_symbol}" + ) + + if spinoffs: + spinoff = spinoffs[0] + if isinstance(spinoff, SpinoffModel): + assert hasattr(spinoff, "new_symbol") + print( + f" Example spinoff: {spinoff.initiating_symbol} spinning off {spinoff.new_symbol}" + ) + + except APIRequestError as e: + if e.status_code in [404, 401, 403]: + pytest.skip(f"API endpoint not available: {e}") + raise + + +if __name__ == "__main__": + # Allow running this file directly for testing + pytest.main([__file__, "-v"]) diff --git a/tests/test_trading/test_order_enhancements.py b/tests/test_trading/test_order_enhancements.py new file mode 100644 index 0000000..4015eb6 --- /dev/null +++ b/tests/test_trading/test_order_enhancements.py @@ -0,0 +1,348 @@ +import json +from unittest.mock import MagicMock, patch + +import pytest + +from py_alpaca_api.exceptions import ValidationError +from py_alpaca_api.models.order_model import OrderModel 
+from py_alpaca_api.trading.orders import Orders + + +class TestOrderEnhancements: + @pytest.fixture + def orders(self): + base_url = "https://paper-api.alpaca.markets/v2" + headers = { + "APCA-API-KEY-ID": "test_key", + "APCA-API-SECRET-KEY": "test_secret", + } + return Orders(base_url=base_url, headers=headers) + + @pytest.fixture + def mock_order_response(self): + return { + "id": "order-123", + "client_order_id": "client-123", + "created_at": "2024-01-15T10:00:00Z", + "updated_at": "2024-01-15T10:00:00Z", + "submitted_at": "2024-01-15T10:00:00Z", + "filled_at": None, + "expired_at": None, + "canceled_at": None, + "failed_at": None, + "replaced_at": None, + "replaced_by": None, + "replaces": None, + "asset_id": "asset-123", + "symbol": "AAPL", + "asset_class": "us_equity", + "notional": None, + "qty": "10", + "filled_qty": "0", + "filled_avg_price": None, + "order_class": "simple", + "order_type": "market", + "type": "market", + "side": "buy", + "time_in_force": "day", + "limit_price": None, + "stop_price": None, + "status": "new", + "extended_hours": False, + "legs": None, + "trail_percent": None, + "trail_price": None, + "hwm": None, + "subtag": None, + "source": None, + } + + def test_replace_order(self, orders, mock_order_response): + with patch("py_alpaca_api.trading.orders.Requests") as mock_requests: + mock_response = MagicMock() + mock_response.text = json.dumps(mock_order_response) + mock_requests.return_value.request.return_value = mock_response + + result = orders.replace_order( + order_id="order-123", + qty=20, + limit_price=150.00, + time_in_force="gtc", + client_order_id="new-client-123", + ) + + assert isinstance(result, OrderModel) + assert result.id == "order-123" + assert result.symbol == "AAPL" + + # Verify the API call + mock_requests.return_value.request.assert_called_once() + call_args = mock_requests.return_value.request.call_args + assert call_args.kwargs["method"] == "PATCH" + assert "order-123" in call_args.kwargs["url"] + assert 
call_args.kwargs["json"]["qty"] == 20 + assert call_args.kwargs["json"]["limit_price"] == 150.00 + assert call_args.kwargs["json"]["time_in_force"] == "gtc" + assert call_args.kwargs["json"]["client_order_id"] == "new-client-123" + + def test_replace_order_no_params(self, orders): + with pytest.raises(ValidationError, match="At least one parameter"): + orders.replace_order(order_id="order-123") + + def test_get_by_client_order_id(self, orders, mock_order_response): + with patch("py_alpaca_api.trading.orders.Requests") as mock_requests: + mock_response = MagicMock() + # Return a list of orders for the filtering to work + mock_response.text = json.dumps([mock_order_response]) + mock_requests.return_value.request.return_value = mock_response + + result = orders.get_by_client_order_id("client-123") + + assert isinstance(result, OrderModel) + assert result.client_order_id == "client-123" + + # Verify the API call - it should query all orders + mock_requests.return_value.request.assert_called_once_with( + method="GET", + url=f"{orders.base_url}/orders", + headers=orders.headers, + params={"status": "all", "limit": 500}, + ) + + def test_cancel_by_client_order_id(self, orders, mock_order_response): + with patch("py_alpaca_api.trading.orders.Requests") as mock_requests: + # First call: get_by_client_order_id to find the order + get_response = MagicMock() + get_response.text = json.dumps([mock_order_response]) + # Second call: cancel_by_id + cancel_response = MagicMock() + cancel_response.text = "{}" + + mock_requests.return_value.request.side_effect = [ + get_response, + cancel_response, + ] + + result = orders.cancel_by_client_order_id("client-123") + + assert "cancelled" in result + + # Verify the API calls + assert mock_requests.return_value.request.call_count == 2 + # First call should be to get all orders + first_call = mock_requests.return_value.request.call_args_list[0] + assert first_call.kwargs["method"] == "GET" + assert first_call.kwargs["url"] == 
f"{orders.base_url}/orders" + # Second call should be to cancel by ID + second_call = mock_requests.return_value.request.call_args_list[1] + assert second_call.kwargs["method"] == "DELETE" + assert "order-123" in second_call.kwargs["url"] + + def test_market_order_with_client_id(self, orders, mock_order_response): + mock_order_response["client_order_id"] = "my-custom-id" + + with patch("py_alpaca_api.trading.orders.Requests") as mock_requests: + mock_response = MagicMock() + mock_response.text = json.dumps(mock_order_response) + mock_requests.return_value.request.return_value = mock_response + + result = orders.market( + symbol="AAPL", + qty=10, + side="buy", + client_order_id="my-custom-id", + ) + + assert isinstance(result, OrderModel) + assert result.client_order_id == "my-custom-id" + + # Verify the API call includes client_order_id + call_args = mock_requests.return_value.request.call_args + assert call_args.kwargs["json"]["client_order_id"] == "my-custom-id" + + def test_market_order_with_order_class(self, orders, mock_order_response): + mock_order_response["order_class"] = "oto" + + with patch("py_alpaca_api.trading.orders.Requests") as mock_requests: + mock_response = MagicMock() + mock_response.text = json.dumps(mock_order_response) + mock_requests.return_value.request.return_value = mock_response + + result = orders.market( + symbol="AAPL", + qty=10, + side="buy", + order_class="oto", + ) + + assert isinstance(result, OrderModel) + assert result.order_class == "oto" + + # Verify the API call includes order_class + call_args = mock_requests.return_value.request.call_args + assert call_args.kwargs["json"]["order_class"] == "oto" + + def test_limit_order_with_enhancements(self, orders, mock_order_response): + mock_order_response["order_class"] = "oco" + mock_order_response["client_order_id"] = "limit-custom-id" + mock_order_response["extended_hours"] = True + + with patch("py_alpaca_api.trading.orders.Requests") as mock_requests: + mock_response = 
MagicMock() + mock_response.text = json.dumps(mock_order_response) + mock_requests.return_value.request.return_value = mock_response + + result = orders.limit( + symbol="AAPL", + limit_price=150.00, + qty=10, + side="buy", + extended_hours=True, + client_order_id="limit-custom-id", + order_class="oco", + ) + + assert isinstance(result, OrderModel) + assert result.order_class == "oco" + assert result.client_order_id == "limit-custom-id" + assert result.extended_hours is True + + # Verify the API call + call_args = mock_requests.return_value.request.call_args + assert call_args.kwargs["json"]["order_class"] == "oco" + assert call_args.kwargs["json"]["client_order_id"] == "limit-custom-id" + assert call_args.kwargs["json"]["extended_hours"] is True + + def test_stop_order_with_enhancements(self, orders, mock_order_response): + with patch("py_alpaca_api.trading.orders.Requests") as mock_requests: + mock_response = MagicMock() + mock_response.text = json.dumps(mock_order_response) + mock_requests.return_value.request.return_value = mock_response + + result = orders.stop( + symbol="AAPL", + stop_price=145.00, + qty=10, + side="sell", + client_order_id="stop-custom-id", + order_class="simple", + ) + + assert isinstance(result, OrderModel) + + # Verify the API call + call_args = mock_requests.return_value.request.call_args + assert call_args.kwargs["json"]["client_order_id"] == "stop-custom-id" + assert call_args.kwargs["json"]["order_class"] == "simple" + + def test_stop_limit_order_with_enhancements(self, orders, mock_order_response): + with patch("py_alpaca_api.trading.orders.Requests") as mock_requests: + mock_response = MagicMock() + mock_response.text = json.dumps(mock_order_response) + mock_requests.return_value.request.return_value = mock_response + + result = orders.stop_limit( + symbol="AAPL", + stop_price=145.00, + limit_price=144.50, + qty=10, + side="sell", + client_order_id="stop-limit-custom-id", + order_class="simple", + ) + + assert isinstance(result, 
OrderModel) + + # Verify the API call + call_args = mock_requests.return_value.request.call_args + assert call_args.kwargs["json"]["client_order_id"] == "stop-limit-custom-id" + assert call_args.kwargs["json"]["order_class"] == "simple" + + def test_trailing_stop_order_with_enhancements(self, orders, mock_order_response): + with patch("py_alpaca_api.trading.orders.Requests") as mock_requests: + mock_response = MagicMock() + mock_response.text = json.dumps(mock_order_response) + mock_requests.return_value.request.return_value = mock_response + + result = orders.trailing_stop( + symbol="AAPL", + qty=10, + trail_percent=2.5, + side="sell", + client_order_id="trail-custom-id", + order_class="simple", + ) + + assert isinstance(result, OrderModel) + + # Verify the API call + call_args = mock_requests.return_value.request.call_args + assert call_args.kwargs["json"]["client_order_id"] == "trail-custom-id" + assert call_args.kwargs["json"]["order_class"] == "simple" + + def test_order_class_priority(self, orders, mock_order_response): + """Test that explicit order_class overrides bracket detection.""" + with patch("py_alpaca_api.trading.orders.Requests") as mock_requests: + mock_response = MagicMock() + mock_response.text = json.dumps(mock_order_response) + mock_requests.return_value.request.return_value = mock_response + + # With both take_profit and stop_loss, but explicit order_class should be used + orders.market( + symbol="AAPL", + qty=10, + take_profit=160.00, + stop_loss=140.00, # Add stop_loss to avoid validation error + order_class="oco", # Explicitly set to oco + ) + + # Verify the API call uses oco, not bracket + call_args = mock_requests.return_value.request.call_args + assert call_args.kwargs["json"]["order_class"] == "oco" + + def test_extended_hours_all_order_types(self, orders, mock_order_response): + """Test that extended_hours parameter works for all order types.""" + with patch("py_alpaca_api.trading.orders.Requests") as mock_requests: + mock_response = 
MagicMock() + mock_response.text = json.dumps(mock_order_response) + mock_requests.return_value.request.return_value = mock_response + + # Test market order + orders.market(symbol="AAPL", qty=10, extended_hours=True) + call_args = mock_requests.return_value.request.call_args + assert call_args.kwargs["json"]["extended_hours"] is True + + # Test limit order + orders.limit(symbol="AAPL", limit_price=150.00, qty=10, extended_hours=True) + call_args = mock_requests.return_value.request.call_args + assert call_args.kwargs["json"]["extended_hours"] is True + + # Test stop order + orders.stop(symbol="AAPL", stop_price=145.00, qty=10, extended_hours=True) + call_args = mock_requests.return_value.request.call_args + assert call_args.kwargs["json"]["extended_hours"] is True + + def test_replace_order_partial_update(self, orders, mock_order_response): + """Test that replace_order can update individual fields.""" + with patch("py_alpaca_api.trading.orders.Requests") as mock_requests: + mock_response = MagicMock() + mock_response.text = json.dumps(mock_order_response) + mock_requests.return_value.request.return_value = mock_response + + # Only update quantity + orders.replace_order(order_id="order-123", qty=50) + + call_args = mock_requests.return_value.request.call_args + body = call_args.kwargs["json"] + assert body["qty"] == 50 + assert "limit_price" not in body + assert "stop_price" not in body + assert "time_in_force" not in body + + # Only update time_in_force + orders.replace_order(order_id="order-123", time_in_force="ioc") + + call_args = mock_requests.return_value.request.call_args + body = call_args.kwargs["json"] + assert body["time_in_force"] == "ioc" + assert "qty" not in body