diff --git a/.github/workflows/codeql-analysis.yaml b/.github/workflows/codeql-analysis.yaml
index 09bb853..d5bbf7e 100644
--- a/.github/workflows/codeql-analysis.yaml
+++ b/.github/workflows/codeql-analysis.yaml
@@ -1,10 +1,12 @@
 name: "CodeQL"

 on:
-  push:
-    branches: [ main ]
   pull_request:
-    branches: [ main ]
+    branches:
+      - main
+  push:
+    branches:
+      - main
   schedule:
     - cron: '0 0 * * 1'
@@ -19,16 +21,16 @@ jobs:
         language: [ 'python' ]

     steps:
       - name: Checkout repository
-        uses: actions/checkout@v2
+        uses: actions/checkout@v6

       # Initializes the CodeQL tools for scanning.
       - name: Initialize CodeQL
-        uses: github/codeql-action/init@v1
+        uses: github/codeql-action/init@v4
         with:
           languages: ${{ matrix.language }}

       - name: Autobuild
-        uses: github/codeql-action/autobuild@v1
+        uses: github/codeql-action/autobuild@v4

       - name: Perform CodeQL Analysis
-        uses: github/codeql-action/analyze@v1
+        uses: github/codeql-action/analyze@v4
diff --git a/.github/workflows/docker-publish.yaml b/.github/workflows/docker-publish.yaml
index da2779b..35ffc43 100644
--- a/.github/workflows/docker-publish.yaml
+++ b/.github/workflows/docker-publish.yaml
@@ -25,18 +25,15 @@ jobs:
     steps:
       - name: Checkout code
-        uses: nschloe/action-cached-lfs-checkout@v1.1.3
+        uses: nschloe/action-cached-lfs-checkout@v1.2.3
         with:
-          exclude: "batbot/*/models/pytorch/"
+          exclude: "examples/example[2-4].wav"

-      - uses: docker/setup-qemu-action@v1
+      - uses: docker/setup-qemu-action@v3
         name: Set up QEMU
         id: qemu
-        with:
-          image: tonistiigi/binfmt:latest
-          platforms: all

-      - uses: docker/setup-buildx-action@v1
+      - uses: docker/setup-buildx-action@v3
         name: Set up Docker Buildx
         id: buildx
@@ -45,35 +42,43 @@ jobs:

       # Log into container registries
       - name: Login to DockerHub
-        uses: docker/login-action@v1
+        uses: docker/login-action@v3
         with:
-          username: batbot
-          password: ${{ secrets.BATBOT_DOCKER_HUB_TOKEN }}
+          username: ${{ vars.BATBOT_DOCKERHUB_USERNAME }}
+          password: ${{ secrets.BATBOT_DOCKERHUB_TOKEN }}

-      # Push tagged image (version tag + latest) to registries
-      - name: Tagged Docker Hub
-        if: ${{ github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags/v') }}
-        run: |
-          VERSION=$(echo ${GITHUB_REF} | sed 's#.*/v##')
-          echo "IMAGE_TAG=${VERSION}" >> $GITHUB_ENV
+      - name: Login to GitHub Container Registry
+        uses: docker/login-action@v3
+        with:
+          registry: ghcr.io
+          username: ${{ github.actor }}
+          password: ${{ secrets.GITHUB_TOKEN }}

-      # Push bleeding-edge image (main tag) to registries
-      - name: Bleeding Edge Docker Hub
-        if: github.ref == 'refs/heads/main'
+      # Push bleeding-edge image ("<branch>" tag) to registries
+      - name: Bleeding Edge Docker Hub (Default Option)
         run: |
-          echo "IMAGE_TAG=main" >> $GITHUB_ENV
+          TAG=$(echo ${GITHUB_REF_NAME} | sed 's/\//-/g')
+          echo "IMAGE_TAG=${TAG}" >> $GITHUB_ENV

-      # Push nightly image (nightly tag) to registries
+      # Push nightly image ("nightly" tag) to registries
       - name: Nightly Docker Hub
         if: github.event_name == 'schedule'
         run: |
           echo "IMAGE_TAG=nightly" >> $GITHUB_ENV

+      # Push tagged image ("<version>" tag) to registries
+      - name: Tagged Docker Hub
+        if: ${{ github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags/v') }}
+        run: |
+          VERSION=$(echo ${GITHUB_REF} | sed 's#.*/v##')
+          echo "IMAGE_TAG=${VERSION}" >> $GITHUB_ENV
+
       # Build images
       - name: Build Batbot
         run: |
           docker buildx build \
             -t kitware/batbot:${{ env.IMAGE_TAG }} \
+            -t ghcr.io/kitware/batbot:${{ env.IMAGE_TAG }} \
             --platform linux/amd64 \
             --push \
             .
@@ -84,6 +89,7 @@ jobs: run: | docker buildx build \ -t kitware/batbot:latest \ + -t ghcr.io/kitware/batbot:latest \ --platform linux/amd64 \ --push \ . diff --git a/.github/workflows/python-publish.yaml b/.github/workflows/python-publish.yaml index ba70428..2a0d59d 100644 --- a/.github/workflows/python-publish.yaml +++ b/.github/workflows/python-publish.yaml @@ -4,6 +4,7 @@ on: push: branches: - main + - jrp/ci-cd-fixes jobs: @@ -14,16 +15,16 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, macos-latest] - python-version: [3.8] + python-version: [3.12] steps: - name: Checkout code - uses: nschloe/action-cached-lfs-checkout@v1.1.3 + uses: nschloe/action-cached-lfs-checkout@v1.2.3 with: - exclude: "batbot/*/models/pytorch/" + exclude: "examples/example[2-4].wav" - - uses: actions/setup-python@v2 - name: Install Python + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} @@ -33,8 +34,9 @@ jobs: pip install build python -m build --wheel --outdir dist/ . - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v6 with: + name: artifact-wheel-${{ matrix.os }}-${{ matrix.python-version }} path: ./dist/*.whl build_sdist: @@ -42,14 +44,14 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout code - uses: nschloe/action-cached-lfs-checkout@v1.1.3 + uses: nschloe/action-cached-lfs-checkout@v1.2.3 with: - exclude: "batbot/*/models/pytorch/" + exclude: "examples/example[2-4].wav" - - uses: actions/setup-python@v2 - name: Install Python + - name: Set up Python 3.12 + uses: actions/setup-python@v6 with: - python-version: '3.8' + python-version: '3.12' - name: Build sdist run: | @@ -57,57 +59,68 @@ jobs: pip install build python -m build --sdist --outdir dist/ . - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v6 with: + name: artifact-sdist path: ./dist/*.tar.gz test_wheel: needs: [build_wheels, build_sdist] runs-on: ubuntu-latest - env: - CLASSIFIER_BATCH_SIZE: 16 # test wheel if: github.event_name == 'push' steps: - - uses: actions/setup-python@v2 - name: Install Python + - name: Set up Python 3.12 + uses: actions/setup-python@v6 with: - python-version: '3.8' + python-version: '3.12' - - uses: actions/download-artifact@v2 + - uses: actions/download-artifact@v7 with: - name: artifact - path: dist + path: artifact - name: Install wheel run: | pip install --upgrade pip pip install wheel + find . + mkdir dist + cp artifact/*-ubuntu-*/*.whl dist/ + cp artifact/*/*.tar.gz dist/ pip install dist/*.whl - name: Test module run: | - python -c "import batbot; batbot.fetch(); batbot.example();" + python -c "import batbot;" - - name: Test CLI - run: | - batbot fetch - batbot example + # - name: Test CLI + # run: | + # batbot example upload_pypi: needs: [test_wheel] runs-on: ubuntu-latest # upload to PyPI on every tag starting with 'v' if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags/v') + environment: + name: pypi + url: https://pypi.org/p/batbot + permissions: + id-token: write steps: - - uses: actions/download-artifact@v2 + - uses: actions/download-artifact@v7 with: - name: artifact - path: dist + path: artifact + + - name: Install wheel + run: | + find . 
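+          # gather the downloaded wheel and sdist artifacts into a flat dist/ directory for publishing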
+ mkdir dist + cp artifact/*-ubuntu-*/*.whl dist/ + cp artifact/*/*.tar.gz dist/ - - uses: pypa/gh-action-pypi-publish@release/v1 + - name: Publish package distributions to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 with: - user: __token__ - password: ${{ secrets.PYPI_PASSWORD }} - # To test: repository_url: https://test.pypi.org/legacy/ + password: ${{ secrets.BATBOT_PYPI_TOKEN }} diff --git a/.github/workflows/testing.yaml b/.github/workflows/testing.yaml index 6a5dbd9..a524773 100644 --- a/.github/workflows/testing.yaml +++ b/.github/workflows/testing.yaml @@ -12,51 +12,46 @@ jobs: fail-fast: false matrix: # Use the same Python version used the Dockerfile - python-version: [3.9] + python-version: [3.12] env: OS: ubuntu-latest PYTHON: ${{ matrix.python-version }} - CLASSIFIER_BATCH_SIZE: 16 steps: # Checkout and env setup - name: Checkout code - uses: nschloe/action-cached-lfs-checkout@v1.1.3 + uses: nschloe/action-cached-lfs-checkout@v1.2.3 with: - exclude: "batbot/*/models/pytorch/" + exclude: "examples/example[3-4].wav" - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt - pip install -r requirements.optional.txt + pip install -r requirements/runtime.txt + pip install -r requirements/optional.txt + pip install -e . - - name: Lint with flake8 + - name: Check with pre-commit run: | - # stop the build if there are Python syntax errors or undefined names - flake8 . --count --show-source --statistics - # exit-zero treats all errors as warnings. - flake8 . --count --exit-zero --max-complexity=10 --statistics + SKIP=hadolint pre-commit + SKIP=hadolint pre-commit run --all-files - - name: Run tests + - name: Run tests and coverage run: | set -ex - pytest --cov=./ --cov-append --random-order-seed=1 - - - name: Run coverage - run: | - coverage xml + pytest --cov=batbot --cov-append --cov-report=xml --random-order-seed=1 - name: Upload coverage to Codecov - uses: codecov/codecov-action@v1.2.1 + continue-on-error: true + uses: codecov/codecov-action@v5 with: - token: ${{ secrets.CODECOV_TOKEN }} + token: ${{ secrets.BATBOT_CODECOV_TOKEN }} files: ./coverage/coverage.xml env_vars: OS,PYTHON fail_ci_if_error: true diff --git a/.gitignore b/.gitignore index 27845ac..fde0cdc 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,8 @@ output.*.jpg *.log* +build/ +dist/ *.egg-info/ .coverage* @@ -16,3 +18,4 @@ docs/_build/ example*.jpg example*.json +.vscode/* diff --git a/Dockerfile b/Dockerfile index b486449..7e57c27 100644 --- a/Dockerfile +++ b/Dockerfile @@ -31,14 +31,11 @@ RUN python3 -m venv /venv # hadolint ignore=DL3003,DL3013 RUN /venv/bin/pip install --no-cache-dir -r requirements/runtime.txt \ && /venv/bin/pip install --no-cache-dir -r requirements/optional.txt \ - && cd tpl/pyastar2d/ \ - && /venv/bin/pip install --no-cache-dir -e . \ - && cd ../.. \ - && /venv/bin/pip install --no-cache-dir -e . \ - && if [ "$(uname -m)" != "aarch64" ] \ - ; then \ - /venv/bin/pip uninstall -y onnxruntime \ - /venv/bin/pip install --no-cache-dir onnxruntime-gpu \ - ; fi + && /venv/bin/pip install --no-cache-dir -e . 
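+    # NOTE: the optional onnxruntime-gpu swap below is currently disabled (kept for reference):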
+    # && if [ "$(uname -m)" != "aarch64" ] \
+    # ; then \
+    #     /venv/bin/pip uninstall -y onnxruntime \
+    #     /venv/bin/pip install --no-cache-dir onnxruntime-gpu \
+    # ; fi

 CMD [".", "/venv/bin/activate", "&&", "exec", "python", "app.py"]
diff --git a/ISSUES.rst b/ISSUES.rst
index 4838b22..078a561 100644
--- a/ISSUES.rst
+++ b/ISSUES.rst
@@ -2,4 +2,9 @@
 Known Issues
 ============

-N/A
+TODO:
+- Fix CI/CD docker build
+- Add API documentation, tutorials, and examples
+- Create example notebooks with Google Colab
+- Create a Discord / Discourse community board
+- [BatAI] Upload training scripts with MLFlow support, database export with celery
diff --git a/MANIFEST.in b/MANIFEST.in
index 3a2e5b8..e38272b 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,5 +1,3 @@
-include pyproject.toml
-
 # Include the README and SECURITY documents
 include *.rst
diff --git a/README.rst b/README.rst
index 6461011..f5513c2 100644
--- a/README.rst
+++ b/README.rst
@@ -4,11 +4,104 @@ Kitware BatBot

 |Tests| |Codecov| |Wheel| |Docker| |ReadTheDocs| |Huggingface|

+.. image:: https://github.com/Kitware/batbot/raw/jrp/ci-cd-fixes/assets/logo.png
+   :alt: Batbot
+   :align: center
+
 .. contents:: Quick Links
    :backlinks: none

 .. sectnum::

+Development Environment
+-----------------------
+
+.. code-block:: bash
+
+   # Find repo on host machine
+   cd ~/code/batbot
+
+   # Build Docker image
+   docker build -t kitware/batbot:latest .
+
+   # Start Docker container using image
+   docker run \
+       -it \
+       --rm \
+       --entrypoint bash \
+       --name batbot \
+       -v $(pwd):/code \
+       kitware/batbot:latest
+
+   ########################
+   # Inside the container #
+   ########################
+
+   # Activate Python environment
+   source /venv/bin/activate
+
+   # Install local version
+   pip install -e .
+
+   # Run batbot
+   batbot --help
+
+Spectrogram Extraction
+----------------------
+
+Here are the steps for extracting the compressed spectrogram:
+
+* Create the STFT
+
+  * Load the original waveform at the original sample rate
+  * Resample waveform to 250kHz
+  * Convert to an STFT spectrogram (fft=512, method=blackmanharris, window=256, hop=16)
+  * Convert complex power STFT to amplitude STFT (dB)
+
+* Normalize the STFT
+
+  * Trim STFT to minimum and maximum frequencies (5kHz to 120kHz)
+  * Subtract the per-frequency median dB (reduce any spectral bias / shift)
+  * Set global dynamic range to -80 dB from the global maximum amplitude
+  * Calculate the global median non-minimum dB (greater than -80dB)
+  * Calculate the median absolute deviation (MAD)
+  * Autogain the dynamic range to (5 * MAD) below the global amplitude median, if necessary
+
+* Quantize the STFT
+
+  * Quantize the floating-point amplitude STFT to a 16-bit integer representation spanning the full dynamic range (65,536 bins)
+  * Vertically flip the spectrogram (low frequencies on bottom) and convert to a C-contiguous array
+
+* Find Candidate Chirps
+
+  * Create a 12ms sliding window with a 4ms stride (3 strides per window)
+  * Keep the time windows that show a substantial right-skew across 10% of the frequency range
+  * Add any user-provided time windows (annotations) to the found candidate windows
+  * Merge any overlapping time windows into a set of contiguous time ranges
+  * Tighten the candidate time ranges (and separate as needed) by repeating the same skew-based filter with a smaller sliding window and stride
+
+* Extract Chirp Metrics
+
+  * *for each candidate chirp*
+  * *Start*: First, find the peak amplitude location.
+  * Step 1 - Normalize the chirp to the full 16-bit range.
+    Calculate a histogram and identify the most common dB and standard deviation.
+    Scale the amplitude values using an inverted PDF, weighting each value by its inverse probability of being noise (values below the most common dB are set to zero)
+  * Step 2 - Apply a median filter and re-normalize
+  * Step 3 - Apply a morphological open operation
+  * Step 4 - Blur the chirp (k=5) and re-normalize
+  * Step 5 - Find contours using the "marching squares" algorithm and select the one that contains the peak amplitude. Extract the convex hull of the contour and smooth the resulting outline
+  * Step 6 - Extract a segmentation mask for the contour
+  * Step 7 - Locate the harmonic (doubling the frequency) and echo (right edge of the contour to the end of the chirp time range) regions. Remove any overlapping noise from the chirp contour.
+  * Step 8 - Locate the start, end, and characteristic frequency points (peak amplitude) and calculate an optimization cost grid for the contour using the masked amplitudes.
+  * Step 9 - Solve a minimum distance optimization using A* that also maximizes the amplitude values from start to end points.
+  * Step 10 - Smooth the contour path, extract the contour's slope, then identify the knee, heel, and other defining attributes.
+  * *End*: Finally, if any of the above steps fails, or the chirp's attributes do not make semantic sense, then skip the candidate chirp.
+
+* Create Output
+
+  * Collect all valid chirp regions and metadata, create a compressed spectrogram
+  * Write the 16-bit spectrogram as a series of 8-bit JPEG image chunks (max width per chunk 50k pixels)
+  * Write the file and chirp metadata to a JSON file.
+
 How to Install
 --------------
@@ -200,16 +293,16 @@ Reference `pre-commit's installation instructions <https://pre-commit.com/#install>`_.
 Furthermore, try to conform to ``PEP8``. You should set up your preferred editor to use ``flake8`` as its Python linter, but pre-commit will ensure compliance before a git commit is completed. This will use the ``flake8`` configuration within ``setup.cfg``, which ignores several errors and stylistic considerations. See the ``setup.cfg`` file for a full and accurate listing of stylistic codes to ignore.

-.. |Tests| image:: https://github.com/Kitware/batbot/actions/workflows/testing.yml/badge.svg?branch=main
-   :target: https://github.com/Kitware/batbot/actions/workflows/testing.yml
+.. |Tests| image:: https://github.com/Kitware/batbot/actions/workflows/testing.yaml/badge.svg?branch=main
+   :target: https://github.com/Kitware/batbot/actions/workflows/testing.yaml
    :alt: GitHub CI

 .. |Codecov| image:: https://codecov.io/gh/Kitware/batbot/branch/main/graph/badge.svg?token=FR6ITMWQNI
    :target: https://app.codecov.io/gh/Kitware/batbot
    :alt: Codecov

-.. |Wheel| image:: https://github.com/Kitware/batbot/actions/workflows/python-publish.yml/badge.svg
-   :target: https://github.com/Kitware/batbot/actions/workflows/python-publish.yml
+.. |Wheel| image:: https://github.com/Kitware/batbot/actions/workflows/python-publish.yaml/badge.svg
+   :target: https://github.com/Kitware/batbot/actions/workflows/python-publish.yaml
    :alt: Python Wheel

 ..
|Docker| image:: https://img.shields.io/docker/image-size/kitware/batbot/latest diff --git a/assets/logo.png b/assets/logo.png new file mode 100644 index 0000000..4501cb8 Binary files /dev/null and b/assets/logo.png differ diff --git a/batbot/__init__.py b/batbot/__init__.py index 9180fb8..111778f 100644 --- a/batbot/__init__.py +++ b/batbot/__init__.py @@ -94,7 +94,8 @@ def pipeline( """ # Generate spectrogram output_paths, metadata_path, metadata = spectrogram.compute(filepath) - raise NotImplementedError + + return output_paths, metadata_path def batch( diff --git a/batbot/spectrogram/__init__.py b/batbot/spectrogram/__init__.py index 074030a..baf5537 100644 --- a/batbot/spectrogram/__init__.py +++ b/batbot/spectrogram/__init__.py @@ -68,6 +68,8 @@ def get_slope_islands(slope_flags): def merge_ranges(ranges, max_val): merged = [] merge = [] + # sort by range start times in case ranges are out of order + ranges = sorted(ranges) for values in ranges: values = list(values) if len(merge) == 0: @@ -114,9 +116,6 @@ def plot_histogram( output_path='.', output_filename='histogram.png', ): - if output_path is None: - return - if max_val is None: max_val = int(image.max()) @@ -139,12 +138,15 @@ def plot_histogram( hist = gaussian_filter1d(hist, smoothing, mode='nearest') hist_original = (hist_original / hist_original.max()) * hist.max() - hist_ = np.argmax(hist) - hist_std = np.abs(image - hist_).mean() + mode_ = np.argmax(hist) # histogram mode csum = np.cumsum(hist) / hist.sum() csum_ = np.where(csum >= csum_threshold)[0].min() + retval = med_, std_, mode_ + if output_path is None: + return retval + y_max = hist.max() * 1.01 # Plot the histogram plt.figure(figsize=(7, 7)) @@ -157,15 +159,15 @@ def plot_histogram( plt.axhline(0, color='black') plt.plot(hist_original, label='Histogram Raw (Non-zero)', color='orange', alpha=0.8) plt.plot(hist, label='Histogram Smoothed (Non-zero)') - plt.plot(csum * y_max, label='Cumulitive Sum') + plt.plot(csum * y_max, label='Cumulative Sum') plt.plot([mean_] * 2, [0, y_max], color='black', linestyle='--', label=f'Mean ({mean_:0.01f})') plt.plot([med_] * 2, [0, y_max], color='red', linestyle='--', label=f'Median ({med_:0.01f})') plt.plot( - [hist_] * 2, + [mode_] * 2, [0, y_max], color='grey', linestyle='--', - label=f'Histogram Peak ({hist_:0.01f})', + label=f'Histogram Peak ({mode_:0.01f} +/- {std_:0.01f})', ) plt.plot( [csum_] * 2, @@ -174,7 +176,7 @@ def plot_histogram( linestyle='--', label=f'CSUM >= {csum_threshold:0.02f} ({csum_:0.01f})', ) - plt.axvspan(hist_ - hist_std, hist_ + hist_std, color='grey', alpha=0.1) + plt.axvspan(mode_ - std_, mode_ + std_, color='grey', alpha=0.1) # plt.plot([med_ - std_] * 2, [0, hist.max()], color='blue', linestyle='--', label=f'Median +/- MAD [{med_ - std_:0.01f} - {med_ + std_:0.01f}]') # plt.plot([med_ + std_] * 2, [0, hist.max()], color='blue', linestyle='--') @@ -193,7 +195,7 @@ def plot_histogram( ) plt.close('all') - return med_, std_, (hist_, hist_std) + return retval def generate_waveplot( @@ -257,26 +259,32 @@ def load_stft( ) # Convert the complex power (amplitude + phase) into amplitude (decibels) stft_db = librosa.power_to_db(np.abs(stft) ** 2, ref=np.max) + # Retrieve time vector in seconds corresponding to STFT + time_vec = librosa.frames_to_time( + range(stft_db.shape[1]), sr=sr, hop_length=hop_length, n_fft=n_fft + ) # Remove frequencies that we do not need [FREQ_MIN - FREQ_MAX] - bands = librosa.fft_frequencies(sr=sr, n_fft=n_fft) + bands = librosa.fft_frequencies(sr=sr, n_fft=n_fft) # band center 
frequencies + delta_f = bands[1] - bands[0] # bandwidth goods = [] for index in range(len(bands)): - band_min = bands[index] - band_max = bands[index + 1] if index < len(bands) - 1 else np.inf + band_min = bands[index] - delta_f / 2.0 + band_max = bands[index] + delta_f / 2.0 + # accept bands with any part of their range within interval [FREQ_MIN, FREQ_MAX] if FREQ_MIN <= band_max and band_min <= FREQ_MAX: goods.append(index) min_index = min(goods) - max_index = max(goods) + 1 + max_index = max(goods) # Return only valid frequency bands - stft_db = stft_db[min_index:max_index, :] - bands = bands[min_index:max_index] + stft_db = stft_db[min_index : max_index + 1, :] + bands = bands[min_index : max_index + 1] waveplot = generate_waveplot(waveform, stft_db, hop_length=hop_length) - return stft_db, waveplot, sr, bands, duration, min_index + return stft_db, waveplot, sr, bands, duration, min_index, time_vec def gain_stft(stft_db, gain_db=80.0, autogain_stddev=5.0): @@ -331,11 +339,20 @@ def normalize_stft(data, value=1.0, dtype=None): return data -def calculate_window_and_stride(stft_db, duration, window_size_ms=12, strides_per_window=3): +def calculate_window_and_stride( + stft_db, duration, window_size_ms=12, strides_per_window=3, time_vec=None +): # Create a horizontal (time) sliding window of Numpy views # Window: ~12ms # Stride: ~4ms - window = stft_db.shape[1] / (duration * 1e3) * window_size_ms + if time_vec is not None: + # use the precise center time per STFT column if provided + delta_t = time_vec[1] - time_vec[0] + window = window_size_ms / delta_t / 1e3 + else: + # estimate the window size based on audio file length and STFT length + window = stft_db.shape[1] / (duration * 1e3) * window_size_ms + stride = window / strides_per_window window = int(round(window)) @@ -345,7 +362,7 @@ def calculate_window_and_stride(stft_db, duration, window_size_ms=12, strides_pe def create_coarse_candidates(stft_db, window, stride, threshold_stddev=3.0): - # Re-calculate the non-zero median DB and MAD + # Re-calculate the non-zero median DB and MAD (scaled like std) temp = stft_db[stft_db > 0] med_db = np.median(temp) std_db = scipy.stats.median_abs_deviation(temp, axis=None, scale='normal') @@ -375,11 +392,11 @@ def create_coarse_candidates(stft_db, window, stride, threshold_stddev=3.0): def filter_candidates_to_ranges( stft_db, candidates, window=16, skew_stddev=2.0, area_percent=0.10, output_path=None ): - # Filter the candidates based on their Normal distribution skewness + # Filter the candidates based on their distribution skewness stride_ = 2 buffer = int(round(window / stride_ / 2)) - idxs = [] + reject_idxs = [] ranges = [] for index, (idx, start, stop) in tqdm.tqdm(list(enumerate(candidates))): # Extract the candidate window of the STFT @@ -393,7 +410,7 @@ def filter_candidates_to_ranges( # Center and clip the skew values skew_thresh = calculate_mean_within_stddev_window(skews, skew_stddev) - # IMPROTANT: Only center positive (right-sided) global skew for the global candidate calculation + # IMPORTANT: Only center positive (right-sided) global skew for the global candidate calculation skew_thresh = max(0, skew_thresh) skews = normalize_skew(skews, skew_thresh) @@ -438,13 +455,13 @@ def filter_candidates_to_ranges( ) plt.close('all') else: - idxs.append(idx) + reject_idxs.append(idx) - return ranges, idxs + return ranges, reject_idxs def plot_chirp_candidates( - stft_db, candidate_dbs, ranges, idxs, output_path='.', output_filename='candidates.png' + stft_db, candidate_dbs, ranges, 
reject_idxs, output_path='.', output_filename='candidates.png' ): if output_path is None: return @@ -453,7 +470,7 @@ def plot_chirp_candidates( cv2.imwrite(join(output_path, f'chirp.{index}.png'), stft_db[:, start:stop]) candidate_dbs_ = candidate_dbs.copy() - candidate_dbs[idxs] = np.nan + candidate_dbs[reject_idxs] = np.nan flags = np.isnan(candidate_dbs) candidate_dbs[flags] = np.nanmin(candidate_dbs) @@ -572,14 +589,6 @@ def tighten_ranges( island_stop += start ranges_.append((island_start, island_stop)) - x = candidate.copy() - for a, b in islands_plotting: - x[:, a] = np.iinfo(x.dtype).max - x[:, b] = np.iinfo(x.dtype).max - # candidate[:, a] = np.iinfo(candidate.dtype).max - # candidate[:, b] = np.iinfo(candidate.dtype).max - # cv2.imwrite('temp2.tif', x) - if output_path: # Plot the skew and spectrogram plt.figure() @@ -710,7 +719,9 @@ def normalize_contour(segment, index, dtype=None, blur=True, kernel=5, output_pa if blur: # segment = cv2.erode(segment, np.ones((3, 3), np.uint8), iterations=1) - segment = cv2.GaussianBlur(segment, (kernel, kernel), cv2.BORDER_DEFAULT) + segment = cv2.GaussianBlur( + segment, (kernel, kernel), sigmaX=4, sigmaY=4, borderType=cv2.BORDER_DEFAULT + ) segment = normalize_stft(segment, None, dtype) @@ -749,7 +760,8 @@ def find_contour_connected_components(segment, index, locations, sequence=4, out def find_harmonic(segmentmask, index, freq_offset, kernel=15, output_path='.'): h = segmentmask.shape[0] locations = np.array(np.where(segmentmask)) - locations[0] = h - ((h - locations[0]) * 2 + freq_offset) + # convert mask to first harmonic (doubled frequency), accounting for flipped frequency axis + locations[0] = h - ((h - locations[0] + freq_offset) * 2) flags = np.logical_and(0 <= locations[0], locations[0] < h) locations = locations[:, flags] @@ -760,7 +772,9 @@ def find_harmonic(segmentmask, index, freq_offset, kernel=15, output_path='.'): harmonic[tuple(locations)] = True harmonic_ = harmonic.astype(np.uint8) * 255 - harmonic_ = cv2.GaussianBlur(harmonic_, (kernel, kernel), cv2.BORDER_DEFAULT) + harmonic_ = cv2.GaussianBlur( + harmonic_, (kernel, kernel), sigmaX=4, sigmaY=4, borderType=cv2.BORDER_DEFAULT + ) write_contour_debug_image(harmonic_, index, 7, 'harmonic', output_path) return harmonic @@ -781,7 +795,9 @@ def find_echo(segmentmask, index, kernel=15, output_path='.'): echo[maxy, maxx:] = True echo_ = echo.astype(np.uint8) * 255 - echo_ = cv2.GaussianBlur(echo_, (kernel, kernel), cv2.BORDER_DEFAULT) + echo_ = cv2.GaussianBlur( + echo_, (kernel, kernel), sigmaX=4, sigmaY=4, borderType=cv2.BORDER_DEFAULT + ) write_contour_debug_image(echo_, index, 7, 'echo', output_path) return echo @@ -793,13 +809,18 @@ def remove_harmonic_and_echo( combined = np.logical_or(harmonic, echo) combined_ = combined.astype(np.uint8) * 255 - combined_ = cv2.GaussianBlur(combined_, (kernel, kernel), cv2.BORDER_DEFAULT) + combined_ = ( + cv2.GaussianBlur( + combined_, (kernel, kernel), sigmaX=4, sigmaY=4, borderType=cv2.BORDER_DEFAULT + ) + / 255.0 + ) write_contour_debug_image(combined_, index, 7, 'combined', output_path) dtype = segment.dtype segment = segment.astype(np.float32) - segment *= 1.0 - combined.astype(np.float32) + segment *= 1.0 - combined_.astype(np.float32) if None not in {med_db, std_db}: segment_threshold = med_db - std_db @@ -830,7 +851,7 @@ def calculate_astar_grid_and_endpoints( costs = segment.copy() segmentmask_ = np.logical_not(segmentmask) costs[segmentmask_] = 0 - write_contour_debug_image(costs, index, 7, 'costs', output_path=output_path) + 
write_contour_debug_image(costs, index, 8, 'costs', output_path=output_path) ys, xs = np.where(costs > 0) points = np.stack([ys, xs], axis=1, dtype=np.float32) @@ -857,15 +878,15 @@ def calculate_astar_grid_and_endpoints( grid += 1 assert grid.min() > 0 - if output_path: - bounds = np.where(np.sum(costs, axis=0) > 0) - left = int(np.min(bounds)) - right = int(np.max(bounds)) - bounds = np.where(np.sum(costs, axis=1) > 0) - top = int(np.min(bounds)) - bottom = int(np.max(bounds)) - boundary = (top, bottom, left, right) + bounds = np.where(np.sum(costs, axis=0) > 0) + left = int(np.min(bounds)) + right = int(np.max(bounds)) + bounds = np.where(np.sum(costs, axis=1) > 0) + top = int(np.min(bounds)) + bottom = int(np.max(bounds)) + boundary = (top, bottom, left, right) + if output_path: height, width = costs.shape value = np.iinfo(canvas.dtype).max @@ -879,11 +900,13 @@ def calculate_astar_grid_and_endpoints( cv2.circle(canvas, begin[::-1], 5, (0, value, 0), -1) cv2.circle(canvas, end[::-1], 5, (0, 0, value), -1) - write_contour_debug_image(canvas, index, 7, 'endpoints', output_path=output_path) + write_contour_debug_image(canvas, index, 8, 'endpoints', output_path=output_path) costs = segment.astype(np.float32) segmentmask_ = segmentmask.astype(np.float32) - segmentmask_ = cv2.GaussianBlur(segmentmask_, (kernel, kernel), cv2.BORDER_DEFAULT) + segmentmask_ = cv2.GaussianBlur( + segmentmask_, (kernel, kernel), sigmaX=4, sigmaY=4, borderType=cv2.BORDER_DEFAULT + ) costs *= segmentmask_ costs = normalize_stft(costs, None, segment.dtype) @@ -930,9 +953,11 @@ def extract_contour_keypoints( der1 = np.nan_to_num(der1, nan=der1min, posinf=der1max, neginf=der1min) der1 = gaussian_filter1d(der1, contour_smoothing_sigma, mode='nearest') + # Retrieve first (knee) and last (heel) locations where slope (dy/dx) magnitude approaches the median value slope_thresh = np.abs(np.median(der1)) slope_flags = np.abs(der1) <= slope_thresh knee_idx, heel_idx = get_slope_islands(slope_flags) + # Retrieve location of minimum slope magnitude between knee and heel fc_idx = knee_idx + int(np.argmin(np.abs(der1[knee_idx:heel_idx]))) if output_path: @@ -1044,7 +1069,7 @@ def significant_contour_path( def scale_pdf_contour(segment, index, output_path='.'): segment = normalize_stft(segment, None, segment.dtype) - med_db, std_db, (peak_db, peak_db_std) = plot_histogram( + med_db, std_db, peak_db = plot_histogram( segment, smoothing=512, ignore_zeros=True, @@ -1055,7 +1080,7 @@ def scale_pdf_contour(segment, index, output_path='.'): assert segment.min() == 0 assert segment.max() == np.iinfo(segment.dtype).max - dist = scipy.stats.norm(peak_db, peak_db_std) + dist = scipy.stats.norm(peak_db, std_db) steps = segment.max() x = np.linspace(0, steps, steps) y = dist.pdf(x) @@ -1070,44 +1095,45 @@ def scale_pdf_contour(segment, index, output_path='.'): if np.any(np.isnan(y)): return segment, None, None - # Plot the histogram - plt.figure(figsize=(7, 7)) - plt.title('Inverse PDF Scaling', y=1.16) - plt.xlim([segment.min(), segment.max()]) - plt.ylim([-0.01, 1.01]) - plt.xlabel('Frequency') - plt.ylabel('Probability') - - plt.axvspan(peak_db - 3 * peak_db_std, peak_db + 3 * peak_db_std, color='grey', alpha=0.15) - plt.axvspan(peak_db - 2 * peak_db_std, peak_db + 2 * peak_db_std, color='grey', alpha=0.15) - plt.axvspan( - peak_db - 1 * peak_db_std, - peak_db + 1 * peak_db_std, - color='grey', - alpha=0.15, - label='Standard Deviations σ={1,2,3}', - ) - plt.plot( - [peak_db] * 2, [0, 1], color='orange', linestyle='--', label='Peak 
Histogram Frequency' - ) - plt.axhline(0, color='black', linestyle='--', alpha=0.5) - plt.axhline(1, color='black', linestyle='--', alpha=0.5) - plt.plot(x, scaling, label='Weighting') + if output_path: + # Plot the histogram + plt.figure(figsize=(7, 7)) + plt.title('Inverse PDF Scaling', y=1.16) + plt.xlim([segment.min(), segment.max()]) + plt.ylim([-0.01, 1.01]) + plt.xlabel('Frequency') + plt.ylabel('Probability') + + plt.axvspan(peak_db - 3 * std_db, peak_db + 3 * std_db, color='grey', alpha=0.15) + plt.axvspan(peak_db - 2 * std_db, peak_db + 2 * std_db, color='grey', alpha=0.15) + plt.axvspan( + peak_db - 1 * std_db, + peak_db + 1 * std_db, + color='grey', + alpha=0.15, + label='Standard Deviations σ={1,2,3}', + ) + plt.plot( + [peak_db] * 2, [0, 1], color='orange', linestyle='--', label='Peak Histogram Frequency' + ) + plt.axhline(0, color='black', linestyle='--', alpha=0.5) + plt.axhline(1, color='black', linestyle='--', alpha=0.5) + plt.plot(x, scaling, label='Weighting') - plt.legend( - bbox_to_anchor=(0.0, 1.02, 1.0, 0.102), - loc=3, - ncol=1, - mode='expand', - borderaxespad=0.0, - ) + plt.legend( + bbox_to_anchor=(0.0, 1.02, 1.0, 0.102), + loc=3, + ncol=1, + mode='expand', + borderaxespad=0.0, + ) - plt.savefig( - join(output_path, f'contour.{index}.00.histogram.scaling.png'), - dpi=150, - bbox_inches='tight', - ) - plt.close('all') + plt.savefig( + join(output_path, f'contour.{index}.00.histogram.scaling.png'), + dpi=150, + bbox_inches='tight', + ) + plt.close('all') scaling = np.hstack((scaling, scaling[-1:])) mask = scaling[segment - segment.min()] @@ -1116,7 +1142,7 @@ def scale_pdf_contour(segment, index, output_path='.'): write_contour_debug_image(segment, index, 1, 'cdf', output_path) - return segment, peak_db, peak_db_std + return segment, peak_db, std_db def morph_open_contour(segment, index, output_path='.'): @@ -1128,21 +1154,35 @@ def morph_open_contour(segment, index, output_path='.'): def find_contour_and_peak( - segment, index, max_locations, peak_db, peak_db_std, threshold_std=2.0, sigma=5, output_path='.' 
+ segment, + index, + max_locations, + peak_db=None, + peak_db_std=None, + threshold_std=2.0, + sigma=5, + output_path='.', + threshold=None, ): - threshold = peak_db - threshold_std * peak_db_std + + if not threshold: + # Apply threshold equal to normalized (and smoothed) segment histogram mode, + # minus the estimated noise standard deviation scaled by threshold_std + # (note that these were computed prior to CDF weighting) + threshold = peak_db - threshold_std * peak_db_std contours = measure.find_contours( segment, level=threshold, fully_connected='high', positive_orientation='high' ) # Display the image and plot all contours found - fig, ax = plt.subplots() - ax.imshow(segment, cmap=plt.cm.gray) - ax.set_xticks([]) - ax.set_xticklabels([]) - ax.set_yticks([]) - ax.set_yticklabels([]) + if output_path: + fig, ax = plt.subplots() + ax.imshow(segment, cmap=plt.cm.gray) + ax.set_xticks([]) + ax.set_xticklabels([]) + ax.set_yticks([]) + ax.set_yticklabels([]) max_points = [Point(*value) for value in max_locations] counter = {} @@ -1154,10 +1194,12 @@ def find_contour_and_peak( if polygon.contains(max_point): found.append(max_location) if len(found) > 0: - ax.plot(contour[:, 1], contour[:, 0], linewidth=2) x = gaussian_filter1d(contour[:, 1], sigma) y = gaussian_filter1d(contour[:, 0], sigma) - ax.plot(x, y, linewidth=1, linestyle='--') + + if output_path: + ax.plot(contour[:, 1], contour[:, 0], linewidth=2) + ax.plot(x, y, linewidth=1, linestyle='--') contour_ = np.vstack((y, x), dtype=contour.dtype).T polygon_ = Polygon(contour).convex_hull @@ -1167,13 +1209,14 @@ def find_contour_and_peak( rr, cc = draw.polygon(contour_[:, 0], contour_[:, 1], shape=segment.shape) segmentmask[rr, cc] = True - plt.savefig( - join(output_path, f'contour.{index}.05.contour.png'), - dpi=150, - pad_inches=0, - bbox_inches='tight', - ) - plt.close('all') + if output_path: + plt.savefig( + join(output_path, f'contour.{index}.05.contour.png'), + dpi=150, + pad_inches=0, + bbox_inches='tight', + ) + plt.close('all') # segmentmask = np.ones(segment.shape, dtype=bool) @@ -1204,7 +1247,9 @@ def calculate_harmonic_and_echo_flags( nonzeros = original > 0 negative = ~np.logical_or(np.logical_or(harmonic, echo), segmentmask) negative_ = negative.astype(np.uint8) * 255 - negative_ = cv2.GaussianBlur(negative_, (kernel, kernel), cv2.BORDER_DEFAULT) + negative_ = cv2.GaussianBlur( + negative_, (kernel, kernel), sigmaX=4, sigmaY=4, borderType=cv2.BORDER_DEFAULT + ) write_contour_debug_image(negative_, index, 7, 'negative', output_path=output_path) negative_skew = scipy.stats.skew(original[np.logical_and(nonzeros, negative)]) @@ -1214,17 +1259,18 @@ def calculate_harmonic_and_echo_flags( - negative_skew ) - shew_thesh = np.abs(negative_skew * 0.1) - harmonic_flag = harmonic_skew >= shew_thesh - echo_flag = echo_skew >= shew_thesh + skew_thresh = np.abs(negative_skew * 0.1) + harmonic_flag = harmonic_skew >= skew_thresh + echo_flag = echo_skew >= skew_thresh harmonic_peak = None if harmonic_flag: - temp = canvas.copy() - temp[:, :, 2][harmonic] = np.iinfo(original.dtype).max - canvas = np.around( - (canvas.astype(np.float32) * 0.5) + (temp.astype(np.float32) * 0.5) - ).astype(canvas.dtype) + if output_path: + temp = canvas.copy() + temp[:, :, 2][harmonic] = np.iinfo(original.dtype).max + canvas = np.around( + (canvas.astype(np.float32) * 0.5) + (temp.astype(np.float32) * 0.5) + ).astype(canvas.dtype) try: temp = original.copy() temp[~harmonic] = 0 @@ -1235,11 +1281,12 @@ def calculate_harmonic_and_echo_flags( echo_peak = None 
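+    # overlay the detected echo region on the debug canvas (when plotting) and locate the echo's peak amplitude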
 if echo_flag:
-        temp = canvas.copy()
-        temp[:, :, 0][echo] = np.iinfo(original.dtype).max
-        canvas = np.around(
-            (canvas.astype(np.float32) * 0.5) + (temp.astype(np.float32) * 0.5)
-        ).astype(canvas.dtype)
+        if output_path:
+            temp = canvas.copy()
+            temp[:, :, 0][echo] = np.iinfo(original.dtype).max
+            canvas = np.around(
+                (canvas.astype(np.float32) * 0.5) + (temp.astype(np.float32) * 0.5)
+            ).astype(canvas.dtype)
         try:
             temp = original.copy()
             temp[~echo] = 0
@@ -1253,7 +1300,7 @@ def calculate_harmonic_and_echo_flags(

 @lp
 def compute_wrapper(
-    wav_filepath, annotations=None, output_folder='.', bitdepth=16, debug=True, **kwargs
+    wav_filepath, annotations=None, output_folder='.', bitdepth=16, debug=False, **kwargs
 ):
     """
     Compute the spectrograms for a given input WAV and saves them to disk.
@@ -1285,7 +1332,7 @@ def compute_wrapper(
     debug_path = get_debug_path(output_folder, wav_filepath, enabled=debug)

     # Load the spectrogram from a WAV file on disk
-    stft_db, waveplot, sr, bands, duration, freq_offset = load_stft(wav_filepath)
+    stft_db, waveplot, sr, bands, duration, freq_offset, time_vec = load_stft(wav_filepath)

     # Apply a dynamic range to a fixed dB range
     stft_db = gain_stft(stft_db)
@@ -1298,33 +1345,39 @@ def compute_wrapper(
     stft_db = np.ascontiguousarray(stft_db[::-1, :])
     bands = bands[::-1]
     y_step_freq = float(bands[0] - bands[1])
-    x_step_ms = float((1e3 * duration) / stft_db.shape[1])
+    x_step_ms = float(1e3 * (time_vec[1] - time_vec[0]))
+    bands = np.around(bands).astype(np.int32).tolist()

     # # Save the spectrogram image to disk
     # cv2.imwrite('debug.tif', stft_db, [cv2.IMWRITE_TIFF_COMPRESSION, 1])

-    # Plot the historgram, ignoring any non-zero values (will no-op if output_path is None)
-    plot_histogram(stft_db, ignore_zeros=True, output_path=debug_path)
+    # Plot the histogram, ignoring any zero values (plotting is skipped if output_path is None)
+    global_med_db, global_std_db, global_peak_db = plot_histogram(
+        stft_db, ignore_zeros=True, smoothing=512, output_path=debug_path
+    )
+    # Estimate a global threshold for finding the edges of bat call contours
+    global_threshold_std = 2.0
+    global_threshold = global_peak_db - global_threshold_std * global_std_db

     # Get a distribution of the max candidate locations
-    window, stride = calculate_window_and_stride(stft_db, duration)
-    candidates, candidate_dbs = create_coarse_candidates(stft_db, window, stride)
+    window, stride = calculate_window_and_stride(stft_db, duration, time_vec=time_vec)
+    candidates, candidate_max_dbs = create_coarse_candidates(stft_db, window, stride)

     # Filter all candidates to the ranges that have a substantial right-side skew
-    ranges, idxs = filter_candidates_to_ranges(stft_db, candidates, output_path=debug_path)
+    ranges, reject_idxs = filter_candidates_to_ranges(stft_db, candidates, output_path=debug_path)

     # Add in user-specified annotations to ranges
     if annotations:
         for start, stop in annotations:
-            start_px = int(np.around((start / duration) * stft_db.shape[1]))
-            stop_px = int(np.around((stop / duration) * stft_db.shape[1])) + 1
+            start_px = int(np.argmin(np.abs(time_vec - start)))
+            stop_px = int(np.argmin(np.abs(time_vec - stop)) + 1)
             ranges.append((start_px, stop_px))

     # Merge all range segments into contiguous range blocks
     ranges = merge_ranges(ranges, stft_db.shape[1])

     # Plot the chirp candidates (will no-op if output_path is None)
-    plot_chirp_candidates(stft_db, candidate_dbs, ranges, idxs, output_path=debug_path)
+    plot_chirp_candidates(stft_db, candidate_max_dbs, ranges, reject_idxs,
output_path=debug_path) # Tighten the ranges by looking for substantial right-side skew (use stride for a smaller sampling window) ranges = tighten_ranges(stft_db, ranges, stride, duration, output_path=debug_path) @@ -1364,8 +1417,15 @@ def compute_wrapper( # segment, med_db, std_db, peak_db = threshold_contour(segment, index, output_path=debug_path) # Step 5 - Find primary contour that contains max amplitude - segmentmask, peak, threshold = find_contour_and_peak( - segment, index, max_locations, peak_db, peak_db_std, output_path=debug_path + # (To use a local instead of global threshold, remove the threshold argument here) + segmentmask, peak, segment_threshold = find_contour_and_peak( + segment, + index, + max_locations, + peak_db, + peak_db_std, + output_path=debug_path, + threshold=global_threshold, ) if peak is None: @@ -1388,35 +1448,37 @@ def compute_wrapper( original, index, segmentmask, harmonic, echo, canvas, output_path=debug_path ) - # Step 8 - Remove harmonic and echo from segmentation + # Remove harmonic and echo from segmentation segment = remove_harmonic_and_echo( - segment, index, harmonic, echo, threshold, output_path=debug_path + segment, index, harmonic, echo, global_threshold, output_path=debug_path ) - # Step 7 - Calculate the A* cost grid and start/end points - costs, grid, begin, end, boundary = calculate_astar_grid_and_endpoints( + # Step 8 - Calculate the A* cost grid and bat call start/end points + costs, grid, call_begin, call_end, boundary = calculate_astar_grid_and_endpoints( segment, index, segmentmask, peak, canvas, output_path=debug_path ) top, bottom, left, right = boundary # Skip chirp if the extracted path covers a small duration or bandwidth bandwidth, duration_, significant = significant_contour_path( - begin, end, y_step_freq, x_step_ms + call_begin, call_end, y_step_freq, x_step_ms ) if not significant: continue - # Step 8 - Extract optimal path from start to end using the cost grid - path = extract_contour_path(grid, begin, end, canvas, index, output_path=debug_path) + # Step 9 - Extract optimal path from start to end using the cost grid + path = extract_contour_path( + grid, call_begin, call_end, canvas, index, output_path=debug_path + ) - # Step 9 - Extract contour keypoints + # Step 10 - Extract contour keypoints path_smoothed, (knee, fc, heel), slopes = extract_contour_keypoints( path, canvas, index, peak, output_path=debug_path ) - # Step 10 - Collect chirp metadata + # Step 11 - Collect chirp metadata metadata = { - 'curve.(khz,ms)': [ + 'curve.(hz,ms)': [ ( bands[y], (start + x) * x_step_ms, @@ -1426,23 +1488,24 @@ def compute_wrapper( 'start.ms': (start + left) * x_step_ms, 'end.ms': (start + right) * x_step_ms, 'duration.ms': (right - left) * x_step_ms, + 'threshold.amp': int(round(255.0 * (segment_threshold / np.iinfo(stft_db.dtype).max))), 'peak f.ms': (start + peak[1]) * x_step_ms, 'fc.ms': (start + bands[fc[1]]) * x_step_ms, 'hi fc:knee.ms': (start + bands[knee[1]]) * x_step_ms, 'lo fc:heel.ms': (start + bands[heel[1]]) * x_step_ms, - 'bandwidth.khz': bandwidth, - 'hi f.khz': bands[top], - 'lo f.khz': bands[bottom], - 'peak f.khz': bands[peak[0]], - 'fc.khz': bands[fc[0]], - 'hi fc:knee.khz': bands[knee[0]], - 'lo fc:heel.khz': bands[heel[0]], + 'bandwidth.hz': bandwidth, + 'hi f.hz': bands[top], + 'lo f.hz': bands[bottom], + 'peak f.hz': bands[peak[0]], + 'fc.hz': bands[fc[0]], + 'hi fc:knee.hz': bands[knee[0]], + 'lo fc:heel.hz': bands[heel[0]], 'harmonic.flag': harmonic_flag, 'harmonic peak f.ms': (start + hamonic_peak[1]) * 
x_step_ms if harmonic_flag else None, - 'harmonic peak f.khz': bands[hamonic_peak[0]] if harmonic_flag else None, + 'harmonic peak f.hz': bands[hamonic_peak[0]] if harmonic_flag else None, 'echo.flag': echo_flag, 'echo peak f.ms': (start + echo_peak[1]) * x_step_ms if echo_flag else None, - 'echo peak f.khz': bands[echo_peak[0]] if echo_flag else None, + 'echo peak f.hz': bands[echo_peak[0]] if echo_flag else None, } metadata.update(slopes) @@ -1450,18 +1513,20 @@ def compute_wrapper( for key, value in list(metadata.items()): if value is None: continue - if key.endswith('.ms') or key.endswith('.khz'): + if key.endswith('.ms'): metadata[key] = round(float(value), 3) + if key.endswith('.hz'): + metadata[key] = int(round(value)) if key.endswith('.flag'): metadata[key] = bool(value) if key.endswith('.y_px/x_px'): key_ = key.replace('.y_px/x_px', '.khz/ms') - metadata[key_] = round(float(value * (y_step_freq / x_step_ms)), 9) + metadata[key_] = round(float(value * ((y_step_freq / 1000.0) / x_step_ms)), 3) metadata.pop(key) - if key.endswith('.(khz,ms)'): + if key.endswith('.(hz,ms)'): metadata[key] = [ ( - round(float(val1), 3), + int(round(val1)), round(float(val2), 3), ) for val1, val2 in value @@ -1469,68 +1534,91 @@ def compute_wrapper( metas.append(metadata) - trim_ms = 1.0 - trim = int(round(trim_ms / x_step_ms)) - trim_begin = max(0, min(segment.shape[1], begin[1] - trim)) - trim_end = max(0, min(segment.shape[1], end[1] + trim)) + # Trim segment around the bat call with a small buffer + buffer_ms = 1.0 + buffer_pix = int(round(buffer_ms / x_step_ms)) + trim_begin = max(0, min(segment.shape[1], call_begin[1] - buffer_pix)) + trim_end = max(0, min(segment.shape[1], call_end[1] + buffer_pix)) segments['stft_db'].append(stft_db[:, start + trim_begin : start + trim_end]) segments['waveplot'].append(waveplot[:, start + trim_begin : start + trim_end]) segments['costs'].append(costs[:, trim_begin:trim_end]) - segments['canvas'].append(canvas[:, trim_begin:trim_end]) + if debug_path: + segments['canvas'].append(canvas[:, trim_begin:trim_end]) - for key in segments: + # Concatenate extracted, trimmed segments and other matrices + for key in list(segments.keys()): value = segments[key] - segments[key] = np.hstack(value) if len(value) > 0 else None.copy() + if len(value) == 0: + segments.pop(key) + continue + segments[key] = np.hstack(value) if debug_path: cv2.imwrite(join(debug_path, 'spectrogram.tif'), stft_db, [cv2.IMWRITE_TIFF_COMPRESSION, 1]) - cv2.imwrite( - join(debug_path, 'spectrogram.compressed.tif'), - segments['stft_db'], - [cv2.IMWRITE_TIFF_COMPRESSION, 1], - ) cv2.imwrite(join(debug_path, 'spectrogram.waveplot.png'), waveplot) - cv2.imwrite(join(debug_path, 'spectrogram.compressed.waveplot.png'), segments['waveplot']) - temp_top = np.stack((segments['stft_db'], segments['stft_db'], segments['stft_db']), axis=2) - temp_bot = cv2.resize( - segments['waveplot'], temp_top.shape[:2][::-1], interpolation=cv2.INTER_LINEAR - ) - temp_bot = temp_bot.astype(np.float32) * ( - np.iinfo(temp_top.dtype).max / np.iinfo(temp_bot.dtype).max - ) - temp_bot = np.around(temp_bot).astype(temp_top.dtype) - temp = np.vstack((temp_top, temp_bot)) - cv2.imwrite(join(debug_path, 'spectrogram.compressed.combined.png'), temp) - cv2.imwrite( - join(debug_path, 'spectrogram.compressed.threshold.tif'), - segments['costs'], - [cv2.IMWRITE_TIFF_COMPRESSION, 1], - ) - temp = segments['costs'].copy() - flags = segments['costs'] == 0 - temp = normalize_stft(temp, None, np.uint8) - temp = cv2.applyColorMap(temp, 
cv2.COLORMAP_JET)
-        temp[:, :, 0][flags] = 0
-        temp[:, :, 1][flags] = 0
-        temp[:, :, 2][flags] = 0
-        cv2.imwrite(
-            join(debug_path, 'spectrogram.compressed.threshold.jet.tif'),
-            temp,
-            [cv2.IMWRITE_TIFF_COMPRESSION, 1],
-        )
-        cv2.imwrite(
-            join(debug_path, 'spectrogram.compressed.keypoints.tif'),
-            segments['canvas'],
-            [cv2.IMWRITE_TIFF_COMPRESSION, 1],
-        )
+
+        if 'stft_db' in segments:
+            cv2.imwrite(
+                join(debug_path, 'spectrogram.compressed.tif'),
+                segments['stft_db'],
+                [cv2.IMWRITE_TIFF_COMPRESSION, 1],
+            )
+
+        if 'waveplot' in segments:
+            cv2.imwrite(
+                join(debug_path, 'spectrogram.compressed.waveplot.png'), segments['waveplot']
+            )
+
+        if 'stft_db' in segments and 'waveplot' in segments:
+            temp_top = np.stack(
+                (segments['stft_db'], segments['stft_db'], segments['stft_db']), axis=2
+            )
+            temp_bot = cv2.resize(
+                segments['waveplot'], temp_top.shape[:2][::-1], interpolation=cv2.INTER_LINEAR
+            )
+            temp_bot = temp_bot.astype(np.float32) * (
+                np.iinfo(temp_top.dtype).max / np.iinfo(temp_bot.dtype).max
+            )
+            temp_bot = np.around(temp_bot).astype(temp_top.dtype)
+            temp = np.vstack((temp_top, temp_bot))
+            cv2.imwrite(join(debug_path, 'spectrogram.compressed.combined.png'), temp)
+
+        if 'costs' in segments:
+            cv2.imwrite(
+                join(debug_path, 'spectrogram.compressed.threshold.tif'),
+                segments['costs'],
+                [cv2.IMWRITE_TIFF_COMPRESSION, 1],
+            )
+            temp = segments['costs'].copy()
+            flags = segments['costs'] == 0
+            temp = normalize_stft(temp, None, np.uint8)
+            temp = cv2.applyColorMap(temp, cv2.COLORMAP_JET)
+            temp[:, :, 0][flags] = 0
+            temp[:, :, 1][flags] = 0
+            temp[:, :, 2][flags] = 0
+            cv2.imwrite(
+                join(debug_path, 'spectrogram.compressed.threshold.jet.tif'),
+                temp,
+                [cv2.IMWRITE_TIFF_COMPRESSION, 1],
+            )
+
+        if 'canvas' in segments:
+            cv2.imwrite(
+                join(debug_path, 'spectrogram.compressed.keypoints.tif'),
+                segments['canvas'],
+                [cv2.IMWRITE_TIFF_COMPRESSION, 1],
+            )

     output_paths = []
     compressed_paths = []
     datas = [
         (output_paths, 'jpg', stft_db),
-        (compressed_paths, 'compressed.jpg', segments['stft_db']),
     ]
+    if 'stft_db' in segments:
+        datas += [
+            (compressed_paths, 'compressed.jpg', segments['stft_db']),
+        ]

     for accumulator, tag, data in datas:
         if data.dtype != np.uint8:
@@ -1559,29 +1647,33 @@ def compute_wrapper(
     metadata = {
         'wav.path': wav_filepath,
         'spectrogram': {
-            'true.path': output_paths,
+            'uncompressed.path': output_paths,
             'compressed.path': compressed_paths,
         },
-        'threshold.amp': int(round(255.0 * (threshold / max_value))),
-        'sr.khz': sr,
+        'global_threshold.amp': int(round(255.0 * (global_threshold / max_value))),
+        'sr.hz': int(sr),
         'duration.ms': round(duration * 1e3, 3),
         'frequencies': {
-            'min.khz': int(FREQ_MIN),
-            'max.khz': int(FREQ_MAX),
-            'pixels.khz': [round(float(band), 3) for band in bands],
+            'min.hz': int(FREQ_MIN),
+            'max.hz': int(FREQ_MAX),
+            'pixels.hz': bands,
         },
         'size': {
-            'true': {
+            'uncompressed': {
                 'width.px': stft_db.shape[1],
                 'height.px': stft_db.shape[0],
             },
-            'compressed': {
-                'width.px': segments['stft_db'].shape[1],
-                'height.px': segments['stft_db'].shape[0],
-            },
+            'compressed': None,
         },
         'segments': metas,
     }
+
+    if 'stft_db' in segments:
+        metadata['size']['compressed'] = {
+            'width.px': segments['stft_db'].shape[1],
+            'height.px': segments['stft_db'].shape[0],
+        }

     metadata_path = join(output_folder, f'{base}.metadata.json')
     with open(metadata_path, 'w') as metafile:
diff --git a/pyproject.toml b/pyproject.toml
index d1e6ae6..9787c3b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,2 +1,3 @@
 [build-system]
requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" diff --git a/requirements/optional.txt b/requirements/optional.txt index c9f61f3..8dcdc59 100644 --- a/requirements/optional.txt +++ b/requirements/optional.txt @@ -2,10 +2,11 @@ black codecov coverage flake8 -imageio +# gradio ipython -line_profiler -onnx +isort +# onnx +# onnxruntime pre-commit pytest pytest-benchmark[histogram] @@ -15,7 +16,12 @@ pytest-profiling pytest-random-order pytest-sugar pytest-xdist +pyupgrade pyyaml +rstcheck +rstcheck[sphinx] Sphinx>=5,<6 sphinx_rtd_theme +# torch +# torchvision xdoctest diff --git a/requirements/runtime.txt b/requirements/runtime.txt index c0f4bc6..9f94278 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -1,19 +1,16 @@ click -codecov -coverage cryptography cython -gradio -imgaug librosa +line_profiler matplotlib numpy -onnxruntime opencv-python-headless Pillow pooch +pyastar2d @ git+https://github.com/bluemellophone/batbot-pyastar2d@master rich +scikit-image +shapely sphinx-click -torch -torchvision tqdm diff --git a/setup.cfg b/setup.cfg index cd2fb4f..d73eb80 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,26 +16,23 @@ project_urls = [options] packages = find: platforms = any -include_package_data = True +include_package_data = true install_requires = click - codecov - coverage cryptography cython - gradio - imgaug librosa + line_profiler matplotlib numpy - onnxruntime opencv-python-headless Pillow pooch + pyastar2d @ git+https://github.com/bluemellophone/batbot-pyastar2d@master rich + scikit-image + shapely sphinx-click - torch - torchvision tqdm python_requires = >=3.7 @@ -43,12 +40,6 @@ python_requires = >=3.7 console_scripts = batbot = batbot.batbot:cli -[bdist_wheel] -universal = 1 - -[aliases] -test=pytest - [tool:pytest] minversion = 5.4 addopts = -v -p no:doctest --xdoctest --xdoctest-style=google --random-order --random-order-bucket=global --cov=./ --cov-report html -m "not separate" --durations-min=1.0 --color=yes --code-highlight=yes --show-capture=log -ra diff --git a/tests/test_batbot.py b/tests/test_batbot.py index 3a1db9b..73bfea8 100644 --- a/tests/test_batbot.py +++ b/tests/test_batbot.py @@ -1,13 +1,5 @@ import batbot -def test_fetch(): - batbot.fetch(pull=False) - batbot.fetch(pull=True) - - batbot.fetch(pull=False, config='usgs') - batbot.fetch(pull=True, config='usgs') - - def test_example(): batbot.example() diff --git a/tests/test_spectrogram.py b/tests/test_spectrogram.py index 69f8fe5..dcf03d5 100644 --- a/tests/test_spectrogram.py +++ b/tests/test_spectrogram.py @@ -4,5 +4,5 @@ def test_spectrogram_compute(): from batbot.spectrogram import compute - wav_filepath = abspath(join('examples', 'example1.wav')) + wav_filepath = abspath(join('examples', 'example2.wav')) output_paths, metadata_path, metadata = compute(wav_filepath) diff --git a/tpl/pyastar2d/.github/workflows/python-publish-linux.yml b/tpl/pyastar2d/.github/workflows/python-publish-linux.yml deleted file mode 100644 index 5c0c809..0000000 --- a/tpl/pyastar2d/.github/workflows/python-publish-linux.yml +++ /dev/null @@ -1,53 +0,0 @@ -# This workflow will upload a Python Package using Twine when a release is created -# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries - -# This workflow uses actions that are not certified by GitHub. 
-# They are provided by a third-party and are governed by -# separate terms of service, privacy policy, and support -# documentation. - -name: Upload Python Package Linux - -on: - release: - types: [published] - -permissions: - contents: read - -jobs: - deploy: - - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - - name: Set up Python - uses: actions/setup-python@v3 - with: - python-version: '3.7' - - name: Run tests - run: | - pip install . - pip install -r requirements-dev.txt - py.test - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install twine build - - name: Build source distribution - run: | - python -m build --sdist - - name: Build manylinux Python wheels - uses: RalfG/python-wheels-manylinux-build@v0.4.2-manylinux2014_x86_64 - with: - python-versions: 'cp37-cp37m cp38-cp38 cp39-cp39 cp310-cp310' - build-requirements: 'numpy' - - name: Publish distribution 📦 to Test PyPI - env: - TWINE_USERNAME: __token__ - #TWINE_PASSWORD: ${{ secrets.TEST_PYPI_API_TOKEN }} - TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} - run: | - #python -m twine upload --repository testpypi dist/*-manylinux*.whl dist/*.tar.gz --verbose --skip-existing - python -m twine upload dist/*-manylinux*.whl dist/*.tar.gz --verbose --skip-existing diff --git a/tpl/pyastar2d/.github/workflows/python-publish-macos.yml b/tpl/pyastar2d/.github/workflows/python-publish-macos.yml deleted file mode 100644 index bebb882..0000000 --- a/tpl/pyastar2d/.github/workflows/python-publish-macos.yml +++ /dev/null @@ -1,52 +0,0 @@ -# This workflow will upload a Python Package using Twine when a release is created -# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries - -# This workflow uses actions that are not certified by GitHub. -# They are provided by a third-party and are governed by -# separate terms of service, privacy policy, and support -# documentation. - -name: Upload Python Package MacOs - -on: - release: - types: [published] - -permissions: - contents: read - -jobs: - deploy: - - runs-on: macos-11 - - strategy: - matrix: - python-version: ["3.7", "3.8", "3.9", "3.10"] - - steps: - - uses: actions/checkout@v3 - - name: Set up Python - uses: actions/setup-python@v3 - with: - python-version: "${{ matrix.python-version }}" - - name: Run tests - run: | - pip install . 
- pip install -r requirements-dev.txt - py.test - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install wheel twine - - name: "Build package for python ${{ matrix.python-version }}" - run: | - python setup.py bdist_wheel - - name: Publish distribution 📦 to Test PyPI - env: - TWINE_USERNAME: __token__ - #TWINE_PASSWORD: ${{ secrets.TEST_PYPI_API_TOKEN }} - TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} - run: | - #python -m twine upload --repository testpypi dist/* --verbose --skip-existing - python -m twine upload dist/* --verbose --skip-existing diff --git a/tpl/pyastar2d/.github/workflows/python-publish-windows.yml b/tpl/pyastar2d/.github/workflows/python-publish-windows.yml deleted file mode 100644 index 9c0a51a..0000000 --- a/tpl/pyastar2d/.github/workflows/python-publish-windows.yml +++ /dev/null @@ -1,52 +0,0 @@ -# This workflow will upload a Python Package using Twine when a release is created -# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries - -# This workflow uses actions that are not certified by GitHub. -# They are provided by a third-party and are governed by -# separate terms of service, privacy policy, and support -# documentation. - -name: Upload Python Package Windows - -on: - release: - types: [published] - -permissions: - contents: read - -jobs: - deploy: - - runs-on: windows-2022 - - strategy: - matrix: - python-version: ["3.7", "3.8", "3.9", "3.10"] - - steps: - - uses: actions/checkout@v3 - - name: Set up Python - uses: actions/setup-python@v3 - with: - python-version: "${{ matrix.python-version }}" - - name: Run tests - run: | - pip install . - pip install -r requirements-dev.txt - py.test - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install wheel twine - - name: "Build package for python ${{ matrix.python-version }}" - run: | - python setup.py bdist_wheel - - name: Publish distribution 📦 to Test PyPI - env: - TWINE_USERNAME: __token__ - #TWINE_PASSWORD: ${{ secrets.TEST_PYPI_API_TOKEN }} - TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} - run: | - #python -m twine upload --repository testpypi dist/* --verbose --skip-existing - python -m twine upload dist/* --verbose --skip-existing diff --git a/tpl/pyastar2d/.gitignore b/tpl/pyastar2d/.gitignore deleted file mode 100644 index 2aaa22d..0000000 --- a/tpl/pyastar2d/.gitignore +++ /dev/null @@ -1,12 +0,0 @@ -*.swp -*.swo -*.so -*.pyc -*.png -*.cache -*.egg-info -*.eggs -.vscode -*.o -build/ -.DS_Store diff --git a/tpl/pyastar2d/.travis.yml b/tpl/pyastar2d/.travis.yml deleted file mode 100644 index 89a2808..0000000 --- a/tpl/pyastar2d/.travis.yml +++ /dev/null @@ -1,17 +0,0 @@ -language: python -python: - - 3.8 -before_install: - - pip install -U pip -# command to install dependencies -install: - - pip install -r requirements-dev.txt - - pip install coveralls # python-coveralls leads to this issue: https://github.com/z4r/python-coveralls/issues/73 - - pip install . 
-# command to run tests
-script:
-  - pytest --cov pyastar2d --cov-report term-missing
-after_success:
-  - coveralls
-notifications:
-  email: false
diff --git a/tpl/pyastar2d/LICENSE b/tpl/pyastar2d/LICENSE
deleted file mode 100644
index 8f0a412..0000000
--- a/tpl/pyastar2d/LICENSE
+++ /dev/null
@@ -1,21 +0,0 @@
-MIT License
-
-Copyright (c) 2020 Hendrik Weideman
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
diff --git a/tpl/pyastar2d/MANIFEST.in b/tpl/pyastar2d/MANIFEST.in
deleted file mode 100644
index 4c89ab4..0000000
--- a/tpl/pyastar2d/MANIFEST.in
+++ /dev/null
@@ -1,2 +0,0 @@
-include requirements.txt
-include src/cpp/experimental_heuristics.h
diff --git a/tpl/pyastar2d/README.md b/tpl/pyastar2d/README.md
deleted file mode 100644
index e51093a..0000000
--- a/tpl/pyastar2d/README.md
+++ /dev/null
@@ -1,158 +0,0 @@
-[![Build Status](https://travis-ci.com/hjweide/pyastar2d.svg?branch=master)](https://travis-ci.com/hjweide/pyastar2d)
-[![Coverage Status](https://coveralls.io/repos/github/hjweide/pyastar2d/badge.svg?branch=master)](https://coveralls.io/github/hjweide/pyastar2d?branch=master)
-[![PyPI version](https://badge.fury.io/py/pyastar2d.svg)](https://badge.fury.io/py/pyastar2d)
-# PyAstar2D
-This is a very simple C++ implementation of the A\* algorithm for pathfinding
-on a two-dimensional grid. The solver itself is implemented in C++, but is
-callable from Python. This combines the speed of C++ with the convenience of
-Python.
-
-I have not done any formal benchmarking, but the solver finds the solution to a
-1802 by 1802 maze in 0.29s and a 4002 by 4002 maze in 0.83s when running on my
-nine-year-old Intel(R) Core(TM) i7-2630QM CPU @ 2.00GHz. See [Example
-Results](#example-results) for more details.
-
-See `src/cpp/astar.cpp` for the core C++ implementation of the A\* shortest
-path search algorithm, `src/pyastar2d/astar_wrapper.py` for the Python wrapper
-and `examples/example.py` for example usage.
-
-When determining legal moves, 4-connectivity is the default, but it is possible
-to set `allow_diagonal=True` for 8-connectivity.
-
-## Installation
-Instructions for installing `pyastar2d` are given below.
-
-### From PyPI
-The easiest way to install `pyastar2d` is directly from the Python package index:
-```
-pip install pyastar2d
-```
-
-### From source
-You can also install `pyastar2d` by cloning this repository and building it
-yourself. If running on Linux or macOS, simply run
-```bash
-pip install .
-```
-from the root directory.
-If you are using Windows you may have to install Cython manually first:
-```bash
-pip install Cython
-pip install .
-```
-To check that everything worked, run the example:
-```bash
-python examples/example.py
-```
-
-### As a dependency
-Add `pyastar2d` to your `requirements.txt` to install the released package from
-PyPI, or pin the development version from GitHub with this line instead:
-```
-pyastar2d @ git+git://github.com/hjweide/pyastar2d.git@master#egg=pyastar2d
-```
-
-## Usage
-A simple example is given below:
-```python
-import numpy as np
-import pyastar2d
-# The minimum cost must be 1 for the heuristic to be valid.
-# The weights array must have np.float32 dtype to be compatible with the C++ code.
-weights = np.array([[1, 3, 3, 3, 3],
-                    [2, 1, 3, 3, 3],
-                    [2, 2, 1, 3, 3],
-                    [2, 2, 2, 1, 3],
-                    [2, 2, 2, 2, 1]], dtype=np.float32)
-# The start and goal coordinates are in matrix coordinates (i, j).
-path = pyastar2d.astar_path(weights, (0, 0), (4, 4), allow_diagonal=True)
-print(path)
-# The path is returned as a numpy array of (i, j) coordinates.
-array([[0, 0],
-       [1, 1],
-       [2, 2],
-       [3, 3],
-       [4, 4]])
-```
-Note that all grid points are represented as `(i, j)` coordinates. An example
-of using `pyastar2d` to solve a maze is given in `examples/maze_solver.py`.
-
-## Example Results
-
-To test the implementation, I grabbed two nasty mazes from Wikipedia. They are
-included in the `mazes` directory, but are originally from here:
-[Small](https://upload.wikimedia.org/wikipedia/commons/c/cf/MAZE.png) and
-[Large](https://upload.wikimedia.org/wikipedia/commons/3/32/MAZE_2000x2000_DFS.png).
-I load the `.png` files as grayscale images, and set the white pixels to 1
-(open space) and the black pixels to `INF` (walls).
-
-To run the examples specify the input and output files using the `--input` and
-`--output` flags. For example, the following commands will solve the small and
-large mazes:
-```
-python examples/maze_solver.py --input mazes/maze_small.png --output solns/maze_small.png
-python examples/maze_solver.py --input mazes/maze_large.png --output solns/maze_large.png
-```
-
-### Small Maze (1802 x 1802):
-```bash
-time python examples/maze_solver.py --input mazes/maze_small.png --output solns/maze_small.png
-Loaded maze of shape (1802, 1802) from mazes/maze_small.png
-Found path of length 10032 in 0.292794s
-Plotting path to solns/maze_small.png
-Done
-
-real 0m1.214s
-user 0m1.526s
-sys 0m0.606s
-```
-The solution found for the small maze is shown below:
-![Maze Small Solution](solns/maze_small.png)
-
-### Large Maze (4002 x 4002):
-```bash
-time python examples/maze_solver.py --input mazes/maze_large.png --output solns/maze_large.png
-Loaded maze of shape (4002, 4002) from mazes/maze_large.png
-Found path of length 783737 in 0.829181s
-Plotting path to solns/maze_large.png
-Done
-
-real 0m29.385s
-user 0m29.563s
-sys 0m0.728s
-```
-The solution found for the large maze is shown below:
-![Maze Large Solution](solns/maze_large.png)
-
-## Motivation
-I recently needed an implementation of the A* algorithm in Python to find the
-shortest path between two points in a cost matrix representing an image.
-Normally I would simply use [networkx](https://networkx.github.io/), but for
-graphs with millions of nodes the overhead incurred to construct the graph can
-be expensive. Considering that I was only interested in graphs that may be
-represented as two-dimensional grids, I decided to implement it myself using
-this special structure of the graph to make various optimizations.
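-
-One such optimization, described in detail below, is storing the grid as a
-flat array. A minimal sketch of the index arithmetic (hypothetical dimensions,
-not from the original implementation):
-```python
-import numpy as np
-
-h, w = 3, 5  # hypothetical grid dimensions
-weights = np.ones((h, w), dtype=np.float32)
-
-# Cell (i, j) lives at flat index i * w + j, the same row-major
-# flattening that np.ravel_multi_index and the C++ solver use.
-i, j = 2, 4
-assert weights.ravel()[i * w + j] == weights[i, j]
-```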
-Specifically, the graph is represented as a one-dimensional array because there -is no need to store the neighbors. Additionally, the lookup tables for -previously-explored nodes (their costs and paths) are also stored as -one-dimensional arrays. The implication of this is that checking the lookup -table can be done in O(1), at the cost of using O(n) memory. Alternatively, we -could store only the nodes we traverse in a hash table to reduce the memory -usage. Empirically I found that replacing the one-dimensional array with a -hash table (`std::unordered_map`) was about five times slower. - -## Tests -The default installation does not include the dependencies necessary to run the -tests. To install these, first run -```bash -pip install -r requirements-dev.txt -``` -before running -```bash -py.test -``` -The tests are fairly basic but cover some of the -more common pitfalls. Pull requests for more extensive tests are welcome. - -## References -1. [A\* search algorithm on Wikipedia](https://en.wikipedia.org/wiki/A*_search_algorithm#Pseudocode) -2. [Pathfinding with A* on Red Blob Games](http://www.redblobgames.com/pathfinding/a-star/introduction.html) diff --git a/tpl/pyastar2d/examples/example.py b/tpl/pyastar2d/examples/example.py deleted file mode 100644 index 53b8fa1..0000000 --- a/tpl/pyastar2d/examples/example.py +++ /dev/null @@ -1,19 +0,0 @@ -import numpy as np -import pyastar2d - -# The start and goal coordinates are in matrix coordinates (i, j). -start = (0, 0) -goal = (4, 4) - -# The minimum cost must be 1 for the heuristic to be valid. -weights = np.array( - [[1, 3, 3, 3, 3], [2, 1, 3, 3, 3], [2, 2, 1, 3, 3], [2, 2, 2, 1, 3], [2, 2, 2, 2, 1]], - dtype=np.float32, -) -print('Cost matrix:') -print(weights) -path = pyastar2d.astar_path(weights, start, goal, allow_diagonal=True) - -# The path is returned as a numpy array of (i, j) coordinates. 
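-# If the goal were unreachable, astar_path would return None rather than an
-# array, so real applications should check the result before using it.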
-print(f'Shortest path from {start} to {goal} found:')
-print(path)
diff --git a/tpl/pyastar2d/examples/maze_solver.py b/tpl/pyastar2d/examples/maze_solver.py
deleted file mode 100644
index cde508c..0000000
--- a/tpl/pyastar2d/examples/maze_solver.py
+++ /dev/null
@@ -1,75 +0,0 @@
-import argparse
-import time
-from os.path import basename, join
-
-import imageio
-import numpy as np
-import pyastar2d
-
-
-def parse_args():
-    parser = argparse.ArgumentParser(
-        description='An example of using pyastar2d to find the solution to a maze'
-    )
-    parser.add_argument(
-        '--input',
-        type=str,
-        default='mazes/maze_small.png',
-        help='Path to the black-and-white image to be used as input.',
-    )
-    parser.add_argument(
-        '--output',
-        type=str,
-        help='Path to where the output will be written',
-    )
-
-    args = parser.parse_args()
-
-    if args.output is None:
-        args.output = join('solns', basename(args.input))
-
-    return args
-
-
-def main():
-    args = parse_args()
-    maze = imageio.imread(args.input)
-
-    if maze is None:
-        print(f'No file found: {args.input}')
-        return
-    else:
-        print(f'Loaded maze of shape {maze.shape} from {args.input}')
-
-    grid = maze.astype(np.float32)
-    grid[grid == 0] = np.inf
-    grid[grid == 255] = 1
-
-    assert grid.min() == 1, 'cost of moving must be at least 1'
-
-    # start is the first white block in the top row
-    (start_j,) = np.where(grid[0, :] == 1)
-    start = np.array([0, start_j[0]])
-
-    # end is the first white block in the final column
-    (end_i,) = np.where(grid[:, -1] == 1)
-    end = np.array([end_i[0], grid.shape[1] - 1])
-
-    t0 = time.time()
-    # set allow_diagonal=True to enable 8-connectivity
-    path = pyastar2d.astar_path(grid, start, end, allow_diagonal=False)
-    dur = time.time() - t0
-
-    if path is not None and path.shape[0] > 0:
-        print(f'Found path of length {path.shape[0]} in {dur:.6f}s')
-        maze = np.stack((maze, maze, maze), axis=2)
-        maze[path[:, 0], path[:, 1]] = (255, 0, 0)
-
-        print(f'Plotting path to {args.output}')
-        imageio.imwrite(args.output, maze)
-    else:
-        print('No path found')
-
-    print('Done')
-
-
-if __name__ == '__main__':
-    main()
diff --git a/tpl/pyastar2d/pyproject.toml b/tpl/pyastar2d/pyproject.toml
deleted file mode 100644
index 9ee21c9..0000000
--- a/tpl/pyastar2d/pyproject.toml
+++ /dev/null
@@ -1,3 +0,0 @@
-[build-system]
-requires = ["setuptools>=64", "wheel", "numpy"]
-build-backend = "setuptools.build_meta"
diff --git a/tpl/pyastar2d/requirements-dev.txt b/tpl/pyastar2d/requirements-dev.txt
deleted file mode 100644
index a7f9fc1..0000000
--- a/tpl/pyastar2d/requirements-dev.txt
+++ /dev/null
@@ -1,4 +0,0 @@
-pytest
-pytest-cov
-pytest-pep8
-pyyaml
diff --git a/tpl/pyastar2d/requirements.txt b/tpl/pyastar2d/requirements.txt
deleted file mode 100644
index 3a353ed..0000000
--- a/tpl/pyastar2d/requirements.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-imageio
-numpy
diff --git a/tpl/pyastar2d/setup.py b/tpl/pyastar2d/setup.py
deleted file mode 100644
index ddfe4aa..0000000
--- a/tpl/pyastar2d/setup.py
+++ /dev/null
@@ -1,52 +0,0 @@
-import pathlib
-
-from setuptools import Extension, find_packages, setup
-
-# Use pathlib for paths
-here = pathlib.Path(__file__).parent.resolve()
-
-# Read README and requirements
-long_description = (here / 'README.md').read_text(encoding='utf-8')
-install_requires = (here / 'requirements.txt').read_text().splitlines()
-
-
-class get_numpy_include:
-    """Defer numpy import until it is actually installed."""
-
-    def __str__(self):
-        import numpy
-
-        return numpy.get_include()
-
-
-# Define the C++ extension
-astar_module = Extension(
-    name='pyastar2d.astar',
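-    # Compiled from the C++ sources listed below; imported at runtime as
-    # pyastar2d.astar by the Python wrapper.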
-    sources=[
-        'src/cpp/astar.cpp',
-        'src/cpp/experimental_heuristics.cpp',
-    ],
-    include_dirs=[
-        'src/cpp',
-        get_numpy_include(),
-    ],
-    extra_compile_args=['-shared'],
-    language='c++',
-)
-
-# Define package metadata
-setup(
-    name='pyastar2d',
-    version='1.0.7',
-    author='Hendrik Weideman',
-    author_email='hjweide@gmail.com',
-    description='A simple implementation of the A* algorithm for path-finding on a two-dimensional grid.',
-    long_description=long_description,
-    long_description_content_type='text/markdown',
-    url='https://github.com/hjweide/pyastar2d',
-    packages=find_packages(where='src', exclude=('tests',)),
-    package_dir={'': 'src'},
-    install_requires=install_requires,
-    python_requires='>=3.7',
-    ext_modules=[astar_module],
-)
diff --git a/tpl/pyastar2d/src/cpp/astar.cpp b/tpl/pyastar2d/src/cpp/astar.cpp
deleted file mode 100644
index 727b05a..0000000
--- a/tpl/pyastar2d/src/cpp/astar.cpp
+++ /dev/null
@@ -1,193 +0,0 @@
-#include <queue>
-#include <limits>
-#include <cmath>
-#include <algorithm>
-
-#include <Python.h>
-#include <numpy/arrayobject.h>
-
-#include <experimental_heuristics.h>
-
-
-const float INF = std::numeric_limits<float>::infinity();
-
-// represents a single pixel
-class Node {
-  public:
-    int idx;          // index in the flattened grid
-    float cost;       // cost of traversing this pixel
-    int path_length;  // the length of the path to reach this node
-
-    Node(int i, float c, int path_length) : idx(i), cost(c), path_length(path_length) {}
-};
-
-// the top of the priority queue is the greatest element by default,
-// but we want the smallest, so flip the sign
-bool operator<(const Node &n1, const Node &n2) {
-  return n1.cost > n2.cost;
-}
-
-// See for various grid heuristics:
-// http://theory.stanford.edu/~amitp/GameProgramming/Heuristics.html#S7
-// L_\inf norm (diagonal distance)
-inline float linf_norm(int i0, int j0, int i1, int j1) {
-  return std::max(std::abs(i0 - i1), std::abs(j0 - j1));
-}
-
-// L_1 norm (manhattan distance)
-inline float l1_norm(int i0, int j0, int i1, int j1) {
-  return std::abs(i0 - i1) + std::abs(j0 - j1);
-}
-
-// Tie breaker heuristic (distance to direct line between start and goal)
-inline float tie_breaker(int i0, int j0, int is, int js, int ig, int jg) {
-  return 0.001 * abs((j0 - jg)*(is - ig) - (js - jg)*(i0 - ig));
-}
-
-// weights: flattened h x w grid of costs
-// h, w: height and width of grid
-// start, goal: index of start/goal in flattened grid
-// diag_ok: if true, allows diagonal moves (8-conn.)
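-// heuristic_override: Heuristic enum value; 0 (DEFAULT) keeps the built-in
-//   L1 / L-inf estimates, while 1 and 2 select the experimental orthogonal ones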
-// paths (output): for each node, stores previous node in path
-static PyObject *astar(PyObject *self, PyObject *args) {
-  const PyArrayObject* weights_object;
-  int h;
-  int w;
-  int start;
-  int goal;
-  int diag_ok;
-  int heuristic_override;
-
-  if (!PyArg_ParseTuple(
-        args, "Oiiiiii", // i = int, O = object
-        &weights_object,
-        &h, &w,
-        &start, &goal,
-        &diag_ok, &heuristic_override
-        ))
-    return NULL;
-
-  float* weights = (float*) weights_object->data;
-  int* paths = new int[h * w];
-  int path_length = -1;
-
-  Node start_node(start, 0., 1);
-
-  float* costs = new float[h * w];
-  for (int i = 0; i < h * w; ++i)
-    costs[i] = INF;
-  costs[start] = 0.;
-
-  std::priority_queue<Node> nodes_to_visit;
-  nodes_to_visit.push(start_node);
-
-  int* nbrs = new int[8];
-
-  int goal_i = goal / w;
-  int goal_j = goal % w;
-  int start_i = start / w;
-  int start_j = start % w;
-
-  heuristic_ptr heuristic_func = select_heuristic(heuristic_override);
-
-  while (!nodes_to_visit.empty()) {
-    // .top() doesn't actually remove the node
-    Node cur = nodes_to_visit.top();
-
-    if (cur.idx == goal) {
-      path_length = cur.path_length;
-      break;
-    }
-
-    nodes_to_visit.pop();
-
-    int row = cur.idx / w;
-    int col = cur.idx % w;
-    // check bounds and find up to eight neighbors: top to bottom, left to right
-    nbrs[0] = (diag_ok && row > 0 && col > 0) ? cur.idx - w - 1 : -1;
-    nbrs[1] = (row > 0) ? cur.idx - w : -1;
-    nbrs[2] = (diag_ok && row > 0 && col + 1 < w) ? cur.idx - w + 1 : -1;
-    nbrs[3] = (col > 0) ? cur.idx - 1 : -1;
-    nbrs[4] = (col + 1 < w) ? cur.idx + 1 : -1;
-    nbrs[5] = (diag_ok && row + 1 < h && col > 0) ? cur.idx + w - 1 : -1;
-    nbrs[6] = (row + 1 < h) ? cur.idx + w : -1;
-    nbrs[7] = (diag_ok && row + 1 < h && col + 1 < w) ? cur.idx + w + 1 : -1;
-
-    float heuristic_cost, current_i, current_j;
-    for (int i = 0; i < 8; ++i) {
-      if (nbrs[i] >= 0) {
-        // Calculate the coordinates of the neighbor
-        current_i = nbrs[i] / w;
-        current_j = nbrs[i] % w;
-
-        // Calculate the tie breaker heuristic
-        float tie_break = tie_breaker(
-          current_i, current_j, start_i, start_j, goal_i, goal_j);
-
-        // Sum of the cost so far and the cost of this move
-        float new_cost = costs[cur.idx] + weights[nbrs[i]] + tie_break;
-        if (new_cost < costs[nbrs[i]]) {
-          // estimate the cost to the goal based on legal moves
-          // Get the heuristic method to use
-          if (heuristic_override == DEFAULT) {
-            if (diag_ok) {
-              heuristic_cost = linf_norm(current_i, current_j, goal_i, goal_j);
-            } else {
-              heuristic_cost = l1_norm(current_i, current_j, goal_i, goal_j);
-            }
-          } else {
-            heuristic_cost = heuristic_func(
-              current_i, current_j, goal_i, goal_j, start_i, start_j);
-          }
-
-          // paths with lower expected cost are explored first
-          float priority = new_cost + heuristic_cost;
-          nodes_to_visit.push(Node(nbrs[i], priority, cur.path_length + 1));
-
-          costs[nbrs[i]] = new_cost;
-          paths[nbrs[i]] = cur.idx;
-        }
-      }
-    }
-  }
-
-  PyObject *return_val;
-  if (path_length >= 0) {
-    npy_intp dims[2] = {path_length, 2};
-    PyArrayObject* path = (PyArrayObject*) PyArray_SimpleNew(2, dims, NPY_INT32);
-    npy_int32 *iptr, *jptr;
-    int idx = goal;
-    for (npy_intp i = dims[0] - 1; i >= 0; --i) {
-      iptr = (npy_int32*) (path->data + i * path->strides[0]);
-      jptr = (npy_int32*) (path->data + i * path->strides[0] + path->strides[1]);
-
-      *iptr = idx / w;
-      *jptr = idx % w;
-
-      idx = paths[idx];
-    }
-
-    return_val = PyArray_Return(path);
-  }
-  else {
-    return_val = Py_BuildValue(""); // no soln --> return None
-  }
-
-  delete[] costs;
-  delete[] nbrs;
-  delete[] paths;
-
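-  // return_val is an Nx2 int32 array of (i, j) coordinates when a path was
-  // found, or None (via the empty Py_BuildValue above) when it was not.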
-  return return_val;
-}
-
-static PyMethodDef astar_methods[] = {
-    {"astar", (PyCFunction)astar, METH_VARARGS, "astar"},
-    {NULL, NULL, 0, NULL}
-};
-
-static struct PyModuleDef astar_module = {
-    PyModuleDef_HEAD_INIT, "astar", NULL, -1, astar_methods
-};
-
-PyMODINIT_FUNC PyInit_astar(void) {
-  import_array();
-  return PyModule_Create(&astar_module);
-}
diff --git a/tpl/pyastar2d/src/cpp/experimental_heuristics.cpp b/tpl/pyastar2d/src/cpp/experimental_heuristics.cpp
deleted file mode 100644
index c6a5a17..0000000
--- a/tpl/pyastar2d/src/cpp/experimental_heuristics.cpp
+++ /dev/null
@@ -1,42 +0,0 @@
-// Please note below heuristics are experimental and only for pretty lines.
-// They may not take the shortest path and require additional cpu cycles.
-
-#include <cstdlib>
-#include <cmath>
-
-#include <experimental_heuristics.h>
-
-
-heuristic_ptr select_heuristic(int h) {
-  switch (h) {
-    case ORTHOGONAL_X:
-      return orthogonal_x;
-    case ORTHOGONAL_Y:
-      return orthogonal_y;
-    default:
-      return NULL;
-  }
-}
-
-// Orthogonal x (moves by x first, then half way by y)
-float orthogonal_x(int i0, int j0, int i1, int j1, int i2, int j2) {
-  int di = std::abs(i0 - i1);
-  int dim = std::abs(i1 - i2);
-  int djm = std::abs(j1 - j2);
-  if (di > (dim * 0.5)) {
-    return di + djm;
-  } else {
-    return std::abs(j0 - j1);
-  }
-}
-
-// Orthogonal y (moves by y first, then half way by x)
-float orthogonal_y(int i0, int j0, int i1, int j1, int i2, int j2) {
-  int dj = std::abs(j0 - j1);
-  int djm = std::abs(j1 - j2);
-  int dim = std::abs(i1 - i2);
-  if (dj > (djm * 0.5)) {
-    return dj + dim;
-  } else {
-    return std::abs(i0 - i1);
-  }
-}
diff --git a/tpl/pyastar2d/src/cpp/experimental_heuristics.h b/tpl/pyastar2d/src/cpp/experimental_heuristics.h
deleted file mode 100644
index 98c0562..0000000
--- a/tpl/pyastar2d/src/cpp/experimental_heuristics.h
+++ /dev/null
@@ -1,20 +0,0 @@
-// Please note below heuristics are experimental and only for pretty lines.
-// They may not take the shortest path and require additional cpu cycles.
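-// select_heuristic() maps a Heuristic enum value to one of the functions
-// declared here; for DEFAULT it returns NULL and astar.cpp falls back to its
-// built-in L1 / L-inf norms.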
- -#ifndef EXPERIMENTAL_HEURISTICS_H_ -#define EXPERIMENTAL_HEURISTICS_H_ - - -enum Heuristic { DEFAULT, ORTHOGONAL_X, ORTHOGONAL_Y }; - -typedef float (*heuristic_ptr)(int, int, int, int, int, int); - -heuristic_ptr select_heuristic(int); - -// Orthogonal x (moves by x first, then half way by y) -float orthogonal_x(int, int, int, int, int, int); - -// Orthogonal y (moves by y first, then half way by x) -float orthogonal_y(int, int, int, int, int, int); - -#endif diff --git a/tpl/pyastar2d/src/pyastar2d/__init__.py b/tpl/pyastar2d/src/pyastar2d/__init__.py deleted file mode 100644 index d0cd0a2..0000000 --- a/tpl/pyastar2d/src/pyastar2d/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from pyastar2d.astar_wrapper import Heuristic, astar_path - -__all__ = ['astar_path', 'Heuristic'] diff --git a/tpl/pyastar2d/src/pyastar2d/astar_wrapper.py b/tpl/pyastar2d/src/pyastar2d/astar_wrapper.py deleted file mode 100644 index b8a7f7d..0000000 --- a/tpl/pyastar2d/src/pyastar2d/astar_wrapper.py +++ /dev/null @@ -1,76 +0,0 @@ -import ctypes -from enum import IntEnum -from typing import Optional, Tuple - -import numpy as np -import pyastar2d.astar - -# Define array types -ndmat_f_type = np.ctypeslib.ndpointer(dtype=np.float32, ndim=1, flags='C_CONTIGUOUS') -ndmat_i2_type = np.ctypeslib.ndpointer(dtype=np.int32, ndim=2, flags='C_CONTIGUOUS') - -# Define input/output types -pyastar2d.astar.restype = ndmat_i2_type # Nx2 (i, j) coordinates or None -pyastar2d.astar.argtypes = [ - ndmat_f_type, # weights - ctypes.c_int, # height - ctypes.c_int, # width - ctypes.c_int, # start index in flattened grid - ctypes.c_int, # goal index in flattened grid - ctypes.c_bool, # allow diagonal - ctypes.c_int, # heuristic_override -] - - -class Heuristic(IntEnum): - """The supported heuristics.""" - - DEFAULT = 0 - ORTHOGONAL_X = 1 - ORTHOGONAL_Y = 2 - - -def astar_path( - weights: np.ndarray, - start: Tuple[int, int], - goal: Tuple[int, int], - allow_diagonal: bool = False, - heuristic_override: Heuristic = Heuristic.DEFAULT, -) -> Optional[np.ndarray]: - """ - Run astar algorithm on 2d weights. - - param np.ndarray weights: A grid of weights e.g. np.ones((10, 10), dtype=np.float32) - param Tuple[int, int] start: (i, j) - param Tuple[int, int] goal: (i, j) - param bool allow_diagonal: Whether to allow diagonal moves - param Heuristic heuristic_override: Override heuristic, see Heuristic(IntEnum) - - """ - assert ( - weights.dtype == np.float32 - ), f'weights must have np.float32 data type, but has {weights.dtype}' - # For the heuristic to be valid, each move must cost at least 1. - if weights.min(axis=None) < 1.0: - raise ValueError('Minimum cost to move must be 1, but got %f' % (weights.min(axis=None))) - # Ensure start is within bounds. - if start[0] < 0 or start[0] >= weights.shape[0] or start[1] < 0 or start[1] >= weights.shape[1]: - raise ValueError(f'Start of {start} lies outside grid.') - # Ensure goal is within bounds. 
-    if goal[0] < 0 or goal[0] >= weights.shape[0] or goal[1] < 0 or goal[1] >= weights.shape[1]:
-        raise ValueError(f'Goal of {goal} lies outside grid.')
-
-    height, width = weights.shape
-    start_idx = np.ravel_multi_index(start, (height, width))
-    goal_idx = np.ravel_multi_index(goal, (height, width))
-
-    path = pyastar2d.astar.astar(
-        weights.flatten(),
-        height,
-        width,
-        start_idx,
-        goal_idx,
-        allow_diagonal,
-        int(heuristic_override),
-    )
-    return path
diff --git a/tpl/pyastar2d/tests/test_astar.py b/tpl/pyastar2d/tests/test_astar.py
deleted file mode 100644
index 532a660..0000000
--- a/tpl/pyastar2d/tests/test_astar.py
+++ /dev/null
@@ -1,141 +0,0 @@
-import numpy as np
-import pyastar2d
-import pytest
-from pyastar2d import Heuristic
-
-
-def test_small():
-    weights = np.array(
-        [[1, 3, 3, 3, 3], [2, 1, 3, 3, 3], [2, 2, 1, 3, 3], [2, 2, 2, 1, 3], [2, 2, 2, 2, 1]],
-        dtype=np.float32,
-    )
-    # Run down the diagonal.
-    path = pyastar2d.astar_path(weights, (0, 0), (4, 4), allow_diagonal=True)
-    expected = np.array([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4]])
-
-    assert np.all(path == expected)
-
-    # Down, right, down, right, etc.
-    path = pyastar2d.astar_path(weights, (0, 0), (4, 4), allow_diagonal=False)
-    expected = np.array([[0, 0], [1, 0], [1, 1], [2, 1], [2, 2], [3, 2], [3, 3], [4, 3], [4, 4]])
-
-    assert np.all(path == expected)
-
-
-def test_no_solution():
-    # Vertical wall.
-    weights = np.ones((5, 5), dtype=np.float32)
-    weights[:, 2] = np.inf
-
-    path = pyastar2d.astar_path(weights, (0, 0), (4, 4), allow_diagonal=True)
-    assert not path
-
-    # Horizontal wall.
-    weights = np.ones((5, 5), dtype=np.float32)
-    weights[2, :] = np.inf
-
-    path = pyastar2d.astar_path(weights, (0, 0), (4, 4), allow_diagonal=True)
-    assert not path
-
-
-def test_match_reverse():
-    # Might fail if there are multiple paths, but this should be rare.
-    h, w = 25, 25
-    weights = (1.0 + 5.0 * np.random.random((h, w))).astype(np.float32)
-
-    fwd = pyastar2d.astar_path(weights, (0, 0), (h - 1, w - 1))
-    rev = pyastar2d.astar_path(weights, (h - 1, w - 1), (0, 0))
-
-    assert np.all(fwd[::-1] == rev)
-
-    fwd = pyastar2d.astar_path(weights, (0, 0), (h - 1, w - 1), allow_diagonal=True)
-    rev = pyastar2d.astar_path(weights, (h - 1, w - 1), (0, 0), allow_diagonal=True)
-
-    assert np.all(fwd[::-1] == rev)
-
-
-def test_narrow():
-    # Column weights.
-    weights = np.ones((5, 1), dtype=np.float32)
-    path = pyastar2d.astar_path(weights, (0, 0), (4, 0))
-
-    expected = np.array([[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]])
-
-    assert np.all(path == expected)
-
-    # Row weights.
-    weights = np.ones((1, 5), dtype=np.float32)
-    path = pyastar2d.astar_path(weights, (0, 0), (0, 4))
-
-    expected = np.array([[0, 0], [0, 1], [0, 2], [0, 3], [0, 4]])
-
-    assert np.all(path == expected)
-
-
-def test_bad_heuristic():
-    # For valid heuristics, the cost to move must be at least 1.
-    weights = (1.0 + 5.0 * np.random.random((10, 10))).astype(np.float32)
-    # An element smaller than 1 should raise a ValueError.
-    bad_cost = np.random.random() / 2.0
-    weights[4, 4] = bad_cost
-
-    with pytest.raises(ValueError) as exc:
-        pyastar2d.astar_path(weights, (0, 0), (9, 9))
-    assert '%f' % bad_cost in exc.value.args[0]
-
-
-def test_invalid_start_and_goal():
-    weights = (1.0 + 5.0 * np.random.random((10, 10))).astype(np.float32)
-    # Test bad start indices.
- with pytest.raises(ValueError) as exc: - pyastar2d.astar_path(weights, (-1, 0), (9, 9)) - assert '-1' in exc.value.args[0] - with pytest.raises(ValueError) as exc: - pyastar2d.astar_path(weights, (10, 0), (9, 9)) - assert '10' in exc.value.args[0] - with pytest.raises(ValueError) as exc: - pyastar2d.astar_path(weights, (0, -1), (9, 9)) - assert '-1' in exc.value.args[0] - with pytest.raises(ValueError) as exc: - pyastar2d.astar_path(weights, (0, 10), (9, 9)) - assert '10' in exc.value.args[0] - # Test bad goal indices. - with pytest.raises(ValueError) as exc: - pyastar2d.astar_path(weights, (0, 0), (-1, 9)) - assert '-1' in exc.value.args[0] - with pytest.raises(ValueError) as exc: - pyastar2d.astar_path(weights, (0, 0), (10, 9)) - assert '10' in exc.value.args[0] - with pytest.raises(ValueError) as exc: - pyastar2d.astar_path(weights, (0, 0), (0, -1)) - assert '-1' in exc.value.args[0] - with pytest.raises(ValueError) as exc: - pyastar2d.astar_path(weights, (0, 0), (0, 10)) - assert '10' in exc.value.args[0] - - -def test_bad_weights_dtype(): - weights = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]], dtype=np.float64) - with pytest.raises(AssertionError) as exc: - pyastar2d.astar_path(weights, (0, 0), (2, 2)) - assert 'float64' in exc.value.args[0] - - -def test_orthogonal_x(): - weights = np.ones((5, 5), dtype=np.float32) - path = pyastar2d.astar_path( - weights, (0, 0), (4, 4), allow_diagonal=False, heuristic_override=Heuristic.ORTHOGONAL_X - ) - expected = np.array([[0, 0], [1, 0], [2, 0], [2, 1], [2, 2], [2, 3], [3, 3], [3, 4], [4, 4]]) - - assert np.all(path == expected) - - -def test_orthogonal_y(): - weights = np.ones((5, 5), dtype=np.float32) - path = pyastar2d.astar_path( - weights, (0, 0), (4, 4), allow_diagonal=False, heuristic_override=Heuristic.ORTHOGONAL_Y - ) - expected = np.array([[0, 0], [0, 1], [0, 2], [1, 2], [2, 2], [3, 2], [3, 3], [4, 3], [4, 4]]) - - assert np.all(path == expected)