diff --git a/README.md b/README.md index c4ab210..1bba96e 100644 --- a/README.md +++ b/README.md @@ -29,9 +29,14 @@ Visit the [project webpage](http://rpg.ifi.uzh.ch/unsupervised_detection.html) f conda case; ```bash conda env create -f environment.yml -bash ./scripts/test_DAVIS2016_raw.sh +conda activate contextual-information-separation +python ./scripts/test_DAVIS2016_raw.py ``` -you can even run inference for no annotated video. +or for FBMS59 data, +```bash +python ./scripts/test_FBMS59_raw.py +``` +you can even run inference for non annotated video. ```bash bash ./scripts/test_video.sh ``` @@ -58,6 +63,14 @@ The datasets can be used without any pre-processing. ### Downloads +You can download all necessary files just running +```bash +python ./scripts/test_DAVIS2016_raw.py +or +python ./scripts/test_FBMS59_raw.py +``` + +But you can also download manually by following below steps. We generate optical flows with a tensorflow implementation of PWCNet, which is an adapted version of [this repository](https://github.com/philferriere/tfoptflow). To compute flows, please download the model checkpoint of PWCNet we used for our experiments, available at [this link](https://drive.google.com/open?id=1gtGx_6MjUQC5lZpl6-Ia718Y_0pvcYou). @@ -92,7 +105,7 @@ You can test a trained model with the function [test\_generator.py](./test_gener An example is provided for the DAVIS 2016 dataset in the [scripts](./scripts) folder. To run it, edit the file [test\_DAVIS2016\_raw.sh](./scripts/test_DAVIS2016_raw.sh) with the paths to the dataset, the optical flow and the model checkpoint. After that, you can test the model with the following command: ```bash -bash ./scripts/test_DAVIS2016_raw.sh +python ./scripts/test_DAVIS2016_raw.py ``` #### Testing for your own video diff --git a/environment.yml b/environment.yml index dddba66..34c803c 100644 --- a/environment.yml +++ b/environment.yml @@ -1,12 +1,238 @@ name: contextual-information-separation channels: - conda-forge + - defaults dependencies: - - python=3 - - tensorflow-gpu=1.13.1 - - opencv - - cudatoolkit=10.1 - - python-gflags - - keras=2.2.4 - - gdown - - pillow \ No newline at end of file + - _libgcc_mutex=0.1=conda_forge + - _openmp_mutex=4.5=2_gnu + - _tflow_select=2.1.0=gpu + - absl-py=2.1.0=pyhd8ed1ab_0 + - alsa-lib=1.2.8=h166bdaf_0 + - aom=3.5.0=h27087fc_0 + - astor=0.8.1=pyh9f0ad1d_0 + - attr=2.5.1=h166bdaf_1 + - beautifulsoup4=4.12.3=pyha770c72_0 + - binutils_impl_linux-64=2.39=he00db2b_1 + - binutils_linux-64=2.39=h5fc0e48_13 + - blas=1.1=openblas + - brotli-python=1.0.9=py37hd23a5d3_7 + - bzip2=1.0.8=h4bc722e_7 + - c-ares=1.34.1=heb4867d_0 + - ca-certificates=2024.11.26=h06a4308_0 + - cached-property=1.5.2=hd8ed1ab_1 + - cached_property=1.5.2=pyha770c72_1 + - cairo=1.16.0=ha61ee94_1012 + - certifi=2024.8.30=pyhd8ed1ab_0 + - charset-normalizer=3.4.0=pyhd8ed1ab_0 + - colorama=0.4.6=pyhd8ed1ab_0 + - cudatoolkit=10.1.243=h6d9799a_13 + - dbus=1.13.6=h5008d03_3 + - expat=2.6.3=h5888daf_0 + - ffmpeg=5.1.2=gpl_h8dda1f0_106 + - fftw=3.3.10=nompi_hf1063bd_110 + - filelock=3.16.1=pyhd8ed1ab_0 + - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 + - font-ttf-inconsolata=3.000=h77eed37_0 + - font-ttf-source-code-pro=2.038=h77eed37_0 + - font-ttf-ubuntu=0.83=h77eed37_3 + - fontconfig=2.14.2=h14ed4e7_0 + - fonts-conda-ecosystem=1=0 + - fonts-conda-forge=1=0 + - freeglut=3.2.2=h9c3ff4c_1 + - freetype=2.12.1=h267a509_2 + - gast=0.5.5=pyhd8ed1ab_0 + - gcc_impl_linux-64=10.4.0=h5231bdf_19 + - gcc_linux-64=10.4.0=h9215b83_13 + - gdown=5.0.1=pyhd8ed1ab_0 + - gettext=0.22.5=he02047a_3 + - gettext-tools=0.22.5=he02047a_3 + - gflags=2.2.2=h6a678d5_1 + - glib=2.82.1=h2ff4ddf_0 + - glib-tools=2.82.1=h4833e2c_0 + - gmp=6.3.0=hac33072_2 + - gnutls=3.7.9=hb077bed_0 + - graphite2=1.3.13=h59595ed_1003 + - grpc-cpp=1.48.1=h30feacc_0 + - grpcio=1.48.1=py37hd557365_0 + - gst-plugins-base=1.21.3=h4243ec0_1 + - gstreamer=1.21.3=h25f0c4b_1 + - gstreamer-orc=0.4.40=hb9d3cd8_0 + - gxx_impl_linux-64=10.4.0=h5231bdf_19 + - gxx_linux-64=10.4.0=h6e491c6_13 + - h5py=3.7.0=nompi_py37hf1ce037_101 + - harfbuzz=5.3.0=h418a68e_0 + - hdf5=1.12.2=nompi_h4df4325_101 + - icu=70.1=h27087fc_0 + - idna=3.10=pyhd8ed1ab_0 + - importlib-metadata=4.11.4=py37h89c1867_0 + - jack=1.9.22=h11f4161_0 + - jasper=2.0.33=h0ff4b12_1 + - jpeg=9e=h0b41bf4_3 + - keras=2.2.4=py37_1 + - keras-applications=1.0.8=py_1 + - keras-preprocessing=1.1.2=pyhd8ed1ab_0 + - kernel-headers_linux-64=3.10.0=he073ed8_17 + - keyutils=1.6.1=h166bdaf_0 + - krb5=1.20.1=h81ceb04_0 + - lame=3.100=h166bdaf_1003 + - lcms2=2.14=h6ed2654_0 + - ld_impl_linux-64=2.39=hcc3a1bd_1 + - lerc=4.0.0=h27087fc_0 + - libabseil=20220623.0=cxx17_h05df665_6 + - libaec=1.1.3=h59595ed_0 + - libasprintf=0.22.5=he8f35ee_3 + - libasprintf-devel=0.22.5=he8f35ee_3 + - libblas=3.9.0=24_linux64_openblas + - libcap=2.67=he9d0100_0 + - libcblas=3.9.0=24_linux64_openblas + - libclang=15.0.7=default_h127d8a8_5 + - libclang13=15.0.7=default_h5d6823c_5 + - libcups=2.3.3=h36d4200_3 + - libcurl=8.1.2=h409715c_0 + - libdb=6.2.32=h9c3ff4c_0 + - libdeflate=1.14=h166bdaf_0 + - libdrm=2.4.123=hb9d3cd8_0 + - libedit=3.1.20191231=he28a2e2_2 + - libev=4.33=hd590300_2 + - libevent=2.1.10=h28343ad_4 + - libexpat=2.6.3=h5888daf_0 + - libffi=3.4.2=h7f98852_5 + - libflac=1.4.3=h59595ed_0 + - libgcc=14.1.0=h77fa898_1 + - libgcc-devel_linux-64=10.4.0=hd38fd1e_19 + - libgcc-ng=14.1.0=h69a702a_1 + - libgcrypt=1.11.0=h4ab18f5_1 + - libgettextpo=0.22.5=he02047a_3 + - libgettextpo-devel=0.22.5=he02047a_3 + - libgfortran=14.1.0=h69a702a_1 + - libgfortran-ng=14.1.0=h69a702a_1 + - libgfortran5=14.1.0=hc5f4f2c_1 + - libglib=2.82.1=h2ff4ddf_0 + - libglu=9.0.0=he1b5a44_1001 + - libgomp=14.1.0=h77fa898_1 + - libgpg-error=1.50=h4f305b6_0 + - libgpuarray=0.7.6=h7f98852_1003 + - libiconv=1.17=hd590300_2 + - libidn2=2.3.7=hd590300_0 + - liblapack=3.9.0=24_linux64_openblas + - liblapacke=3.9.0=24_linux64_openblas + - libllvm15=15.0.7=hadd5161_1 + - libnghttp2=1.58.0=h47da74e_0 + - libnsl=2.0.1=hd590300_0 + - libogg=1.3.5=h4ab18f5_0 + - libopenblas=0.3.27=pthreads_hac2b453_1 + - libopencv=4.6.0=py37h7b66c90_5 + - libopus=1.3.1=h7f98852_1 + - libpciaccess=0.18=hd590300_0 + - libpng=1.6.44=hadc24fc_0 + - libpq=15.3=hbcd7760_1 + - libprotobuf=3.21.8=h6239696_0 + - libsanitizer=10.4.0=h5246dfb_19 + - libsndfile=1.2.2=hc60ed4a_1 + - libsqlite=3.46.1=hadc24fc_0 + - libssh2=1.11.0=h0841786_0 + - libstdcxx=14.1.0=hc0a3c3a_1 + - libstdcxx-devel_linux-64=10.4.0=hd38fd1e_19 + - libstdcxx-ng=14.1.0=h4852527_1 + - libsystemd0=253=h8c4010b_1 + - libtasn1=4.19.0=h166bdaf_0 + - libtiff=4.4.0=h82bc61c_5 + - libtool=2.4.7=he02047a_1 + - libudev1=253=h0b41bf4_1 + - libunistring=0.9.10=h7f98852_0 + - libuuid=2.38.1=h0b41bf4_0 + - libva=2.18.0=h0b41bf4_0 + - libvorbis=1.3.7=h9c3ff4c_0 + - libvpx=1.11.0=h9c3ff4c_3 + - libwebp-base=1.4.0=hd590300_0 + - libxcb=1.13=h7f98852_1004 + - libxkbcommon=1.5.0=h79f4944_1 + - libxml2=2.10.3=hca2bb57_4 + - libzlib=1.3.1=hb9d3cd8_2 + - lz4-c=1.9.4=hcb278e6_0 + - mako=1.3.5=pyhd8ed1ab_0 + - markdown=3.6=pyhd8ed1ab_0 + - markupsafe=2.1.1=py37h540881e_1 + - mock=5.1.0=pyhd8ed1ab_0 + - mpg123=1.32.6=h59595ed_0 + - mysql-common=8.0.33=hf1915f5_6 + - mysql-libs=8.0.33=hca2cd23_6 + - ncurses=6.5=he02047a_1 + - nettle=3.9.1=h7ab15ed_0 + - nspr=4.35=h27087fc_0 + - nss=3.105=hd34e28f_0 + - numpy=1.21.6=py37h976b520_0 + - openblas=0.3.27=pthreads_h9eca1d5_1 + - opencv=4.6.0=py37h89c1867_5 + - openh264=2.3.1=hcb278e6_2 + - openjpeg=2.5.0=h7d73246_1 + - openssl=3.1.7=hb9d3cd8_0 + - p11-kit=0.24.1=hc5aa10d_0 + - packaging=23.2=pyhd8ed1ab_0 + - pcre2=10.44=hba22ea6_2 + - pillow=9.2.0=py37h850a105_2 + - pip=24.0=pyhd8ed1ab_0 + - pixman=0.43.2=h59595ed_0 + - protobuf=4.21.8=py37hd23a5d3_0 + - pthread-stubs=0.4=hb9d3cd8_1002 + - pulseaudio=16.1=hcb278e6_3 + - pulseaudio-client=16.1=h5195f5e_3 + - pulseaudio-daemon=16.1=ha8d29e2_3 + - py-opencv=4.6.0=py37hf05f0b3_5 + - pygpu=0.7.6=py37hb1e94ed_1003 + - pysocks=1.7.1=py37h89c1867_5 + - python=3.7.12=hf930737_100_cpython + - python-gflags=3.1.2=py_0 + - python_abi=3.7=4_cp37m + - pyyaml=6.0=py37h540881e_4 + - qt-main=5.15.6=hf6cd601_5 + - re2=2022.06.01=h27087fc_1 + - readline=8.2=h8228510_1 + - requests=2.32.2=pyhd8ed1ab_0 + - scipy=1.7.3=py37hf838250_2 + - setuptools=69.0.3=pyhd8ed1ab_0 + - six=1.16.0=pyh6c4a22f_0 + - soupsieve=2.3.2.post1=pyhd8ed1ab_0 + - sqlite=3.46.1=h9eae976_0 + - svt-av1=1.4.1=hcb278e6_0 + - sysroot_linux-64=2.17=h4a8ded7_17 + - tensorboard=1.13.1=py37_0 + - tensorflow=1.13.1=py37_0 + - tensorflow-estimator=1.13.0=py_0 + - tensorflow-gpu=1.13.1=h0d30ee6_0 + - termcolor=2.3.0=pyhd8ed1ab_0 + - theano=1.0.5=py37hd23a5d3_3 + - tk=8.6.13=noxft_h4845f30_101 + - tqdm=4.66.5=pyhd8ed1ab_0 + - typing_extensions=4.7.1=pyha770c72_0 + - tzdata=2024b=hc8b5060_0 + - urllib3=2.2.1=pyhd8ed1ab_0 + - werkzeug=2.2.3=pyhd8ed1ab_0 + - wheel=0.42.0=pyhd8ed1ab_0 + - x264=1!164.3095=h166bdaf_2 + - x265=3.5=h924138e_3 + - xcb-util=0.4.0=h516909a_0 + - xcb-util-image=0.4.0=h166bdaf_0 + - xcb-util-keysyms=0.4.0=h516909a_0 + - xcb-util-renderutil=0.3.9=h166bdaf_0 + - xcb-util-wm=0.4.1=h516909a_0 + - xkeyboard-config=2.38=h0b41bf4_0 + - xorg-fixesproto=5.0=hb9d3cd8_1003 + - xorg-inputproto=2.3.2=hb9d3cd8_1003 + - xorg-kbproto=1.0.7=hb9d3cd8_1003 + - xorg-libice=1.1.1=hb9d3cd8_1 + - xorg-libsm=1.2.4=he73a12e_1 + - xorg-libx11=1.8.4=h0b41bf4_0 + - xorg-libxau=1.0.11=hb9d3cd8_1 + - xorg-libxdmcp=1.1.5=hb9d3cd8_0 + - xorg-libxext=1.3.4=h0b41bf4_2 + - xorg-libxfixes=5.0.3=h7f98852_1004 + - xorg-libxi=1.7.10=h7f98852_0 + - xorg-libxrender=0.9.10=h7f98852_1003 + - xorg-renderproto=0.11.1=hb9d3cd8_1003 + - xorg-xextproto=7.3.0=hb9d3cd8_1004 + - xorg-xproto=7.0.31=hb9d3cd8_1008 + - xz=5.2.6=h166bdaf_0 + - yaml=0.2.5=h7f98852_2 + - zipp=3.15.0=pyhd8ed1ab_0 + - zstd=1.5.6=ha6fb4c9_0 diff --git a/scripts/download_util.py b/scripts/download_util.py new file mode 100644 index 0000000..430dd86 --- /dev/null +++ b/scripts/download_util.py @@ -0,0 +1,283 @@ +import os +import requests +import zipfile +import subprocess +import logging +from tqdm import tqdm +import time + +logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s") + + +def download_file(url, destination_path, retries=3, timeout=60): + """_summary_ + Downloads a file from a URL to a destination path with progress bar, retries, and timeout. + + Args: + url (str): URL to download from. + destination_path (str): full path of the destination file name. + retries (int, optional): Number of retries on failure. Defaults to 3. + timeout (int, optional): Timeout in seconds for the request. Defaults to 60. + """ + for attempt in range(retries): + logging.info(f"Attempt {attempt + 1} of {retries} to download {url}") + try: + response = requests.get(url, stream=True, timeout=timeout) + response.raise_for_status() # raise an exception for HTTP errors + total_size = int(response.headers.get("content-length", 0)) + block_size = 1024 # 1 Kibibyte + + with open(destination_path, "wb") as file: + with tqdm( + desc="Downloading file", + total=total_size, + unit="iB", + unit_scale=True, + unit_divisor=block_size, + ) as bar: + for data in response.iter_content(block_size): + size = file.write(data) + bar.update(size) + + # Check size after download + actual_size = ( + os.path.getsize(destination_path) + if os.path.exists(destination_path) + else 0 + ) + if total_size != 0 and actual_size != total_size: + logging.warning( + f"Downloaded file size mismatch for {destination_path}: Server reported {total_size} bytes, but got {actual_size} bytes." + "Processing anyway as download completed." + ) + elif total_size == 0 and actual_size == 0: + logging.error(f"Downloaded file is empty: {destination_path}.") + if os.path.exists(destination_path): + os.remove(destination_path) + if attempt < retries - 1: + logging.info(f"Retrying download for {url} in 5 seconds...") + time.sleep(5) + continue + else: + logging.error( + f"Failed to download {url} after {retries} attempts. Giving up." + ) + return False + + logging.info(f"Downloaded successfully from {url} to {destination_path}") + return True + + except requests.exceptions.Timeout: + logging.error( + f"Timeout ({timeout}s) occurred while downloading {url} on attempt {attempt + 1}." + ) + if os.path.exists(destination_path): + os.remove(destination_path) + if attempt < retries - 1: + logging.info(f"Retrying download for {url} in 5 seconds...") + time.sleep(5) + else: + logging.error( + f"Failed to download {url} after {retries} attempts. Giving up." + ) + return False + except requests.exceptions.RequestException as e: + logging.error(f"Failed to download {url} on attempt {attempt + 1}: {e}") + if os.path.exists(destination_path): + os.remove(destination_path) + if attempt < retries - 1: + logging.info(f"Retrying download for {url} in 5 seconds...") + time.sleep(5) + else: + logging.error( + f"Failed to download {url} after {retries} attempts. Giving up." + ) + return False + except Exception as e: + logging.error(f"An unknown error occurred while downloading {url}: {e}") + if os.path.exists(destination_path): + os.remove(destination_path) + return False + + return False # If all attempts fail + + +def extract_zip(zip_path, extract_to_dir): + """_summary_ + Extracts a zip file to a specified directory. + + Args: + zip_path (_type_): _description_ + extract_to_dir (_type_): _description_ + """ + if not os.path.exists(zip_path): + logging.error(f"Zip file {zip_path} does not exist.") + return False + + try: + logging.info(f"Extracting {zip_path} to {extract_to_dir}") + with zipfile.ZipFile(zip_path, "r") as zip_ref: + zip_ref.extractall(extract_to_dir) + logging.info(f"Successfully extracted {zip_path} to {extract_to_dir}") + return True + + except zipfile.BadZipFile as e: + logging.error(f"The file {zip_path} is not a zip or currupted: {e}") + # Don't remove the zip file here, as user might want to retry. + return False + except Exception as e: + logging.error(f"An unkown error occurred while extracting {zip_path}: {e}") + # Don't remove the zip file here, as user might want to retry. + return False + + +def run_gdown(folder_url, destination_dir): + """_summary_ + Downloads a folder from Google Drive using gdown. + + Args: + folder_url (_type_): _description_ + destination_dir (_type_): _description_. destination full path directory name. + """ + try: + logging.info(f"Downloading {folder_url} to {destination_dir}") + os.makedirs(destination_dir, exist_ok=True) + subprocess.run( + ["gdown", "--folder", folder_url, "-O", destination_dir], check=True + ) + logging.info(f"Successfully downloaded {folder_url} to {destination_dir}") + return True + except subprocess.CalledProcessError as e: + logging.error(f"Failed to download {folder_url}: {e}") + return False + except Exception as e: + logging.error(f"An unknown error occurred while downloading {folder_url}: {e}") + return False + + +def ensure_dataset( + dataset_name, + download_urls, + destination_dir, +): + """ + Checks for dataset, downloads and extracts if missing. + + Args: + dataset_name (str): Name of the dataset. (e.g. "FBMS") + download_urls (str): URLs to download the dataset zip file. (e.g. ["https://example.com/dataset.zip","https://example.com/dataset2.zip"]) + destination_dir (str): Directory to save the dataset. (e.g. "/home/user/downloads"). So we assumed the zip file has top level directory (e.g. Trainset/, Testser/), and the unzip result will not be mixed if we specify the same destination. + """ + zip_extracted_path = os.path.join(destination_dir, dataset_name) + if os.path.exists(zip_extracted_path): + logging.info( + f"Dataset '{dataset_name}' already exists at {destination_dir}. Skipping download." + ) + return True + + logging.info( + f"Dataset '{dataset_name}' not found at {destination_dir}. Downloading..." + ) + for download_url in download_urls: + os.makedirs(destination_dir, exist_ok=True) + zip_basefname = dataset_name + zip_filepath = os.path.join(destination_dir, f"{zip_basefname}.zip") + + if not download_file(download_url, zip_filepath): + logging.error( + f"Failed to download {dataset_name} zip file from {download_url}." + ) + return False # Stop if download fails + + if not extract_zip(zip_filepath, zip_extracted_path): + logging.error(f"Failed to extract {zip_filepath} to {zip_extracted_path}.") + return False # Stop if extraction fails + + # Check if the extracted folder exists + if not os.path.exists(zip_extracted_path): + logging.error( + f"Expected extracted folder '{zip_extracted_path}' not found after unzipping." + ) + return False + + # Cleanup zip file + try: + logging.info(f"Cleaning up {zip_filepath}") + os.remove(zip_filepath) + except OSError as e: + logging.warning(f"Could not remove zip file {zip_filepath}. Error: {e}") + + logging.info(f"Successfully prepared dataset '{dataset_name}'.") + return True + + +def ensure_pwc_checkpoint(gdown_folder_url, pwc_ckpt_path): + """Checks for PWCNet checkpoint files, downloads via gdown if missing.""" + # Check for one of the expected files (adjust extensions if needed) + pwc_dir = os.path.dirname(pwc_ckpt_path) + if os.path.exists(pwc_ckpt_path): + logging.info(f"PWCNet checkpoint found at {pwc_dir}. Skipping download.") + return True + + logging.info("PWCNet checkpoint not found. Attempting download via gdown...") + # gdown downloads the *contents* of the folder into the target directory + if not run_gdown( + gdown_folder_url, pwc_dir, description="Downloading PWCNet checkpoint" + ): + return False + + # Verify again after download attempt + if not os.path.exists(pwc_ckpt_path): + logging.error( + f"PWCNet checkpoint file {pwc_ckpt_path} still not found after gdown attempt." + ) + return False + + logging.info("Successfully prepared PWCNet checkpoint.") + return True + + +def ensure_model_checkpoint(download_url, ckpt_path, zip_path=None): + """ + Checks for the specific model checkpoint, downloads and extracts parent zip if missing. + Args: + download_url (str): URL to download the zip file containing the model checkpoint. + ckpt_path (str): Path to the specific model checkpoint file. + zip_path (str): Path to the already downloaded zip file containing the model checkpoint. otherwise you don't need to set this.""" + if os.path.exists(ckpt_path): + logging.info(f"Model checkpoint found: {ckpt_path}. Skipping download.") + return True + + logging.info(f"Model checkpoint {ckpt_path} not found.") + + # Check if the zip exists first + if not os.path.exists(zip_path): + logging.info(f"Checkpoint archive {zip_path} not found. Downloading...") + if not download_file(download_url, zip_path): + return False # Stop if download fails + else: + logging.info(f"Checkpoint archive {zip_path} found.") + + # Extract the archive (even if it existed, maybe extraction failed before) + extract_target_dir = os.path.dirname(ckpt_path) + if not extract_zip(zip_path, extract_target_dir): + return False + + # Verify the specific checkpoint exists after extraction + if not os.path.exists(ckpt_path): + logging.error( + f"Model checkpoint {ckpt_path} still not found after extracting {zip_path}." + ) + logging.error( + "Please check the contents of the zip file and the expected path." + ) + return False + + # Clean up the zip file after successful extraction and verification + try: + os.remove(zip_path) + except OSError as e: + logging.warning(f"Could not remove checkpoint zip file {zip_path}. Error: {e}") + + logging.info(f"Successfully prepared model checkpoint {ckpt_path}.") + return True diff --git a/scripts/test_DAVIS2016_raw.py b/scripts/test_DAVIS2016_raw.py new file mode 100644 index 0000000..da7380e --- /dev/null +++ b/scripts/test_DAVIS2016_raw.py @@ -0,0 +1,116 @@ +import os +import subprocess +import logging +import download_util + + +def main(): + # --- Fixed Parameters --- + dataset_name = "DAVIS2016" + dataset_download_urls = [ + "https://graphics.ethz.ch/Downloads/Data/Davis/DAVIS-data.zip", + ] + model_ckpt_basedir = ( + "davis_best_model" # this is the name under "unsupervised_detection_models" + ) + rootdidr_name = ( + dataset_name + "/DAVIS" + ) # This is neccesary becuase DAVIS2016 zip top directory is DAVIS and unsupervised training needs root_dir as under DAVIS directory structure. + + LOG_LEVEL = logging.INFO + TEST_CROP = 0.9 # FBMS default + TEST_TEMPORAL_SHIFT = 1 + GENERATE_VISUALIZATION = True + # --- End Fixed Parameters --- + + logging.basicConfig(level=LOG_LEVEL, format="[%(levelname)s] %(message)s") + + # --- Define Paths and URLs for FBMS-59 --- + script_dir = os.path.dirname(os.path.realpath(__file__)) + base_dir = os.path.abspath( + os.path.join(script_dir, "..") + ) # Go up one level from scripts/ + download_dir = os.path.join(base_dir, "download") + results_dir = os.path.join(base_dir, "results", dataset_name) + + # Model Checkpoint + model_ckpt_zip_url = "https://rpg.ifi.uzh.ch/data/unsupervised_detection_models.zip" + model_ckpt_base = os.path.join( + download_dir, + "unsupervised_detection_models", + model_ckpt_basedir, + "model.best", + ) # this will be used as the argument of test_generator.py + model_ckpt_path = ( + model_ckpt_base + ".data-00000-of-00001" + ) # actual checkpoint path. + + # PWCNet Checkpoint + pwc_gdown_folder_url = ( + "https://drive.google.com/drive/folders/1gtGx_6MjUQC5lZpl6-Ia718Y_0pvcYou" + ) + pwc_ckpt_path = os.path.join( + download_dir, + "pwcnet-lg-6-2-multisteps-chairsthingsmix", + "pwcnet.ckpt-595000.data-00000-of-00001", + ) + + # --- Ensure Prerequisites --- + logging.info(f"--- Checking Prerequisites for {dataset_name} ---") + os.makedirs(download_dir, exist_ok=True) + os.makedirs(results_dir, exist_ok=True) # Ensure results dir exists + + # 1. Dataset + if not download_util.ensure_dataset( + dataset_name, + dataset_download_urls, + download_dir, + ): + logging.error(f"Failed to prepare dataset {dataset_name}. Exiting.") + exit(1) + + # 2. Model Checkpoint + if not download_util.ensure_model_checkpoint(model_ckpt_zip_url, model_ckpt_path): + logging.error(f"Failed to prepare model checkpoint {model_ckpt_path}. Exiting.") + exit(1) + + # 3. PWCNet Checkpoint + if not download_util.ensure_pwc_checkpoint(pwc_gdown_folder_url, pwc_ckpt_path): + logging.error(f"Failed to prepare PWCNet checkpoint {pwc_ckpt_path}. Exiting.") + exit(1) + + logging.info("--- Prerequisites Met ---") + + # --- Run Test Generator --- + logging.info("Starting test generation...") + test_command = [ + "python3", + os.path.join(base_dir, "test_generator.py"), # Path to test_generator.py + f"--dataset={dataset_name}", + f"--ckpt_file={model_ckpt_base}", + f"--flow_ckpt={pwc_ckpt_path}", + f"--test_crop={TEST_CROP}", + f"--test_temporal_shift={TEST_TEMPORAL_SHIFT}", + f"--root_dir={download_dir}/{rootdidr_name}", + f"--test_save_dir={results_dir}", + ] + if GENERATE_VISUALIZATION: + test_command.append("--generate_visualization") + + logging.info(f"Running command: {' '.join(test_command)}") + try: + # Run the command from the script's directory + subprocess.run(test_command, check=True, cwd=script_dir) + logging.info("Test generation finished successfully.") + except subprocess.CalledProcessError as e: + logging.error(f"Test generation failed with error code {e.returncode}.") + exit(1) + except FileNotFoundError: + logging.error( + f"Error: test_generator.py not found in {script_dir}. Make sure it exists." + ) + exit(1) + + +if __name__ == "__main__": + main() diff --git a/scripts/test_DAVIS2016_raw.sh b/scripts/test_DAVIS2016_raw.sh deleted file mode 100755 index 2430be2..0000000 --- a/scripts/test_DAVIS2016_raw.sh +++ /dev/null @@ -1,51 +0,0 @@ -#!/bin/bash -# -# Script to compute raw results (before post-processing) -### - -SCRIPT_DIR=$(dirname "$(realpath "$0")") - -# parameters -DOWNLOAD_DIR="${SCRIPT_DIR}/../download" -CKPT_FILE="${DOWNLOAD_DIR}/unsupervised_detection_models/davis_best_model/model.best" -PWC_CKPT_FILE="${DOWNLOAD_DIR}/pwcnet-lg-6-2-multisteps-chairsthingsmix/pwcnet.ckpt-595000.data-00000-of-00001" -DATASET_FILE="${DOWNLOAD_DIR}/DAVIS" -RESULT_DIR="${SCRIPT_DIR}/../results/DAVIS" - - -echo "[INFO] start downloading data..." -mkdir -p ${DOWNLOAD_DIR} -( - cd ${DOWNLOAD_DIR} - if [ ! -f ${CKPT_FILE}.data* ]; then - echo "[INFO] no checkpoint file found. start downloading it." - wget https://rpg.ifi.uzh.ch/data/unsupervised_detection_models.zip - unzip unsupervised_detection_models.zip - rm unsupervised_detection_models.zip - fi - if [ ! -f ${PWC_CKPT_FILE} ]; then - echo "[INFO] no pwc checkpoint file found. start downloading it." - gdown --folder "https://drive.google.com/drive/folders/1gtGx_6MjUQC5lZpl6-Ia718Y_0pvcYou" - fi - if [ ! -e ${DATASET_FILE} ]; then - echo "[INFO] no DAVIS data found. start downloading it." - wget https://graphics.ethz.ch/Downloads/Data/Davis/DAVIS-data.zip - unzip DAVIS-data.zip - rm DAVIS-data.zip - fi -) -echo "[INFO] finished downloading." - - -echo "[INFO] start running a test..." -mkdir -p ${RESULT_DIR} -python3 test_generator.py \ ---dataset=DAVIS2016 \ ---ckpt_file=$CKPT_FILE \ ---flow_ckpt=$PWC_CKPT_FILE \ ---test_crop=0.9 \ ---test_temporal_shift=1 \ ---root_dir=$DATASET_FILE \ ---generate_visualization=True \ ---test_save_dir=${RESULT_DIR} -echo "[INFO] finished the test." diff --git a/scripts/test_FBMS59_raw.py b/scripts/test_FBMS59_raw.py new file mode 100644 index 0000000..b8ec618 --- /dev/null +++ b/scripts/test_FBMS59_raw.py @@ -0,0 +1,115 @@ +import os +import subprocess +import logging +import download_util + + +def main(): + # --- Fixed Parameters --- + dataset_name = "FBMS" + dataset_download_urls = [ + "https://lmb.informatik.uni-freiburg.de/resources/datasets/fbms/FBMS_Trainingset.zip", + "https://lmb.informatik.uni-freiburg.de/resources/datasets/fbms/FBMS_Testset.zip", + ] + model_ckpt_basedir = ( + "fbms_best_model" # this is the name under "unsupervised_detection_models" + ) + rootdir_name = dataset_name + + LOG_LEVEL = logging.INFO + TEST_CROP = 0.9 # FBMS default + TEST_TEMPORAL_SHIFT = 1 + GENERATE_VISUALIZATION = True + # --- End Fixed Parameters --- + + logging.basicConfig(level=LOG_LEVEL, format="[%(levelname)s] %(message)s") + + # --- Define Paths and URLs for FBMS-59 --- + script_dir = os.path.dirname(os.path.realpath(__file__)) + base_dir = os.path.abspath( + os.path.join(script_dir, "..") + ) # Go up one level from scripts/ + download_dir = os.path.join(base_dir, "download") + results_dir = os.path.join(base_dir, "results", dataset_name) + + # Model Checkpoint + model_ckpt_zip_url = "https://rpg.ifi.uzh.ch/data/unsupervised_detection_models.zip" + model_ckpt_base = os.path.join( + download_dir, + "unsupervised_detection_models", + model_ckpt_basedir, + "model.best", + ) # this will be used as the argument of test_generator.py + model_ckpt_path = ( + model_ckpt_base + ".data-00000-of-00001" + ) # actual checkpoint path. + + # PWCNet Checkpoint + pwc_gdown_folder_url = ( + "https://drive.google.com/drive/folders/1gtGx_6MjUQC5lZpl6-Ia718Y_0pvcYou" + ) + pwc_ckpt_path = os.path.join( + download_dir, + "pwcnet-lg-6-2-multisteps-chairsthingsmix", + "pwcnet.ckpt-595000.data-00000-of-00001", + ) + + # --- Ensure Prerequisites --- + logging.info(f"--- Checking Prerequisites for {dataset_name} ---") + os.makedirs(download_dir, exist_ok=True) + os.makedirs(results_dir, exist_ok=True) # Ensure results dir exists + + # 1. Dataset + if not download_util.ensure_dataset( + dataset_name, + dataset_download_urls, + download_dir, + ): + logging.error(f"Failed to prepare dataset {dataset_name}. Exiting.") + exit(1) + + # 2. Model Checkpoint + if not download_util.ensure_model_checkpoint(model_ckpt_zip_url, model_ckpt_path): + logging.error(f"Failed to prepare model checkpoint {model_ckpt_path}. Exiting.") + exit(1) + + # 3. PWCNet Checkpoint + if not download_util.ensure_pwc_checkpoint(pwc_gdown_folder_url, pwc_ckpt_path): + logging.error(f"Failed to prepare PWCNet checkpoint {pwc_ckpt_path}. Exiting.") + exit(1) + + logging.info("--- Prerequisites Met ---") + + # --- Run Test Generator --- + logging.info("Starting test generation...") + test_command = [ + "python3", + os.path.join(base_dir, "test_generator.py"), # Path to test_generator.py + f"--dataset={dataset_name}", + f"--ckpt_file={model_ckpt_base}", + f"--flow_ckpt={pwc_ckpt_path}", + f"--test_crop={TEST_CROP}", + f"--test_temporal_shift={TEST_TEMPORAL_SHIFT}", + f"--root_dir={download_dir}/{rootdir_name}", + f"--test_save_dir={results_dir}", + ] + if GENERATE_VISUALIZATION: + test_command.append("--generate_visualization") + + logging.info(f"Running command: {' '.join(test_command)}") + try: + # Run the command from the script's directory + subprocess.run(test_command, check=True, cwd=script_dir) + logging.info("Test generation finished successfully.") + except subprocess.CalledProcessError as e: + logging.error(f"Test generation failed with error code {e.returncode}.") + exit(1) + except FileNotFoundError: + logging.error( + f"Error: test_generator.py not found in {script_dir}. Make sure it exists." + ) + exit(1) + + +if __name__ == "__main__": + main()