diff --git a/src/indexing.jl b/src/indexing.jl
index 305481e..9f40765 100644
--- a/src/indexing.jl
+++ b/src/indexing.jl
@@ -145,3 +145,17 @@ function index(indexer::Indexer)
     # check if all relevant files are saved
     _check_all_files_are_saved(indexer.config.index_path)
 end
+
+# Multi-line (REPL) display of an `Indexer`: collection size, checkpoint,
+# collection path (only when it is a non-empty String) and index path.
+function Base.show(io::IO, ::MIME"text/plain", indexer::Indexer)
+    print(io, "ColBERT Indexer:\n")
+    print(io, " collection size: $(length(indexer.collection)) documents\n")
+    print(io, " checkpoint: $(indexer.config.checkpoint)\n")
+    collection_path = indexer.config.collection
+    if collection_path isa String && !isempty(collection_path)
+        print(io, " collection path: $(collection_path)\n")
+    end
+    # last line carries no "\n": whatever displays the value ends the output
+    print(io, " index path: $(indexer.config.index_path)")
+end
diff --git a/src/infra/config.jl b/src/infra/config.jl
index 06fbc5d..d8cb243 100644
--- a/src/infra/config.jl
+++ b/src/infra/config.jl
@@ -63,7 +63,7 @@ Base.@kwdef struct ColBERTConfig
     query_token::String = "[Q]"
     doc_token::String = "[D]"
 
-    # resource settings 
+    # resource settings
     checkpoint::String = "colbert-ir/colbertv2.0"
     collection::Union{String, Vector{String}} = ""
 
@@ -88,3 +88,35 @@ Base.@kwdef struct ColBERTConfig
     nprobe::Int = 2
     ncandidates::Int = 8192
 end
+
+# Multi-line (REPL) display of a `ColBERTConfig`, grouped into model, documents,
+# queries, indexing, search and hardware sections.
+function Base.show(io::IO, ::MIME"text/plain", config::ColBERTConfig)
+    print(io, "ColBERTConfig:\n")
+    print(io, " model:\n")
+    print(io, " checkpoint: $(config.checkpoint)\n")
+    print(io, " dim: $(config.dim)\n")
+    print(io, " documents:\n")
+    # a String collection is shown as-is, a vector of documents by its length
+    print(io,
+        " collection: $(config.collection isa String ? config.collection : "$(length(config.collection)) documents")\n")
+    print(io, " max length: $(config.doc_maxlen)\n")
+    print(io, " mask punctuation: $(config.mask_punctuation)\n")
+    print(io, " queries:\n")
+    print(io, " max length: $(config.query_maxlen)\n")
+    print(io, " attend to mask: $(config.attend_to_mask_tokens)\n")
+    print(io, " indexing:\n")
+    print(io, " path: $(config.index_path)\n")
+    print(io, " batch size: $(config.index_bsize)\n")
+    print(io, " chunk size: $(config.chunksize)\n")
+    print(io, " compression bits: $(config.nbits)\n")
+    print(io, " kmeans iterations: $(config.kmeans_niters)\n")
+    print(io, " search:\n")
+    print(io, " nprobe: $(config.nprobe)\n")
+    print(io, " ncandidates: $(config.ncandidates)\n")
+    print(io, " hardware:\n")
+    print(io, " gpu: $(config.use_gpu)\n")
+    print(io, " rank: $(config.rank)\n")
+    # last line carries no "\n": whatever displays the value ends the output
+    print(io, " nranks: $(config.nranks)")
+end
diff --git a/src/searching.jl b/src/searching.jl
index e617931..4af0e6a 100644
--- a/src/searching.jl
+++ b/src/searching.jl
@@ -126,3 +126,15 @@ function search(searcher::Searcher, query::String, k::Int)
     pids, scores = pids[indices], scores[indices]
     pids[1:k], scores[1:k]
 end
+
+# Multi-line (REPL) display of a `Searcher`: checkpoint, index path, the sum of
+# `doclens` (printed as total embeddings) and the 2nd dimension of `centroids`.
+function Base.show(io::IO, ::MIME"text/plain", searcher::Searcher)
+    print(io, "ColBERT Searcher:\n")
+    print(io, " checkpoint: $(searcher.config.checkpoint)\n")
+    print(io, " index path: $(searcher.config.index_path)\n")
+    print(io, " embeddings:\n")
+    print(io, " total: $(sum(searcher.doclens))\n")
+    # last line carries no "\n": whatever displays the value ends the output
+    print(io, " centroids: $(size(searcher.centroids, 2))")
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index c6d6df1..72c4ec5 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -21,7 +21,7 @@ const FLOAT_TYPES = [Float16, Float32, Float64]
 include("indexing/codecs/residual.jl")
 include("indexing/collection_indexer.jl")
 
-# modelling operations 
+# modelling operations
 include("modelling/tokenization/tokenizer_utils.jl")
 include("modelling/embedding_utils.jl")
 
@@ -29,6 +29,9 @@ include("modelling/embedding_utils.jl")
 include("searching.jl")
 include("search/ranking.jl")
 
+# show operations
+include("show_methods.jl")
+
 # utils
 include("utils.jl")
 
diff --git a/test/show_methods.jl b/test/show_methods.jl
new file mode 100644
index 0000000..05c0fc8
--- /dev/null
+++ b/test/show_methods.jl
@@ -0,0 +1,25 @@
+@testset "show methods" begin
+    mktempdir() do dir
+        config = ColBERTConfig(
+            checkpoint = "dummy-checkpoint",
+            index_path = dir,
+            collection = ["doc1", "doc2"]
+        )
+
+        str = sprint(show, MIME("text/plain"), config)
+        @test occursin(" checkpoint: dummy-checkpoint", str)
+        @test occursin(" path: $dir", str)
+        @test occursin(" collection: 2 documents", str)
+        # multi-line `show` output must not end with a newline
+        @test !endswith(str, "\n")
+
+        # a String collection is shown verbatim instead of a document count
+        path_config = ColBERTConfig(
+            checkpoint = "dummy-checkpoint",
+            index_path = dir,
+            collection = "docs.tsv"
+        )
+        path_str = sprint(show, MIME("text/plain"), path_config)
+        @test occursin(" collection: docs.tsv", path_str)
+    end
+end