diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 7b49969c02149..3e3db999c92ed 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -90,10 +90,8 @@ class ModelBase:
     use_temp_file: bool
     lazy: bool
     dry_run: bool
-    part_names: list[str]
-    is_safetensors: bool
     hparams: dict[str, Any]
-    tensor_names: set[str] | None
+    model_tensors: dict[str, Callable[[], Tensor]]
     gguf_writer: gguf.GGUFWriter
     model_name: str | None
     metadata_override: Path | None
@@ -137,25 +135,8 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
         self.dry_run = dry_run
         self.remote_hf_model_id = remote_hf_model_id
         self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
-        if remote_hf_model_id is not None:
-            self.is_safetensors = True
-
-            def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
-                logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
-                remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
-                self.tensor_names = set(name for name in remote_tensors.keys())
-                for name, remote_tensor in remote_tensors.items():
-                    yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor))
-
-            self.get_tensors = get_remote_tensors
-        else:
-            prefix = "model" if not self.is_mistral_format else "consolidated"
-            self.part_names = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
-            self.is_safetensors = len(self.part_names) > 0
-            if not self.is_safetensors:
-                self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
         self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams
-        self.tensor_names = None
+        self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id)
         self.metadata_override = metadata_override
         self.model_name = model_name
         self.dir_model_card = dir_model  # overridden in convert_lora_to_gguf.py
@@ -171,6 +152,8 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
                 logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
                 self.ftype = gguf.LlamaFileType.MOSTLY_BF16
 
+        self.dequant_model()
+
         # Configure GGUF Writer
         self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
                                            split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
@@ -192,67 +175,215 @@ def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any:
             return None
         raise KeyError(f"could not find any of: {keys}")
 
-    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
-        tensor_names_from_parts: set[str] = set()
+    def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Callable[[], Tensor]]:
+        tensors: dict[str, Callable[[], Tensor]] = {}
+
+        if remote_hf_model_id is not None:
+            is_safetensors = True
+
+            logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
+            remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
+            for name, remote_tensor in remote_tensors.items():
+                tensors[name] = lambda r=remote_tensor: LazyTorchTensor.from_remote_tensor(r)
+
+            return tensors
+
+        prefix = "model" if not self.is_mistral_format else "consolidated"
+        part_names: list[str] = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
+        is_safetensors: bool = len(part_names) > 0
+        if not is_safetensors:
+            part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
+
+        tensor_names_from_index: set[str] = set()
 
         if not self.is_mistral_format:
-            index_name = "model.safetensors" if self.is_safetensors else "pytorch_model.bin"
+            index_name = "model.safetensors" if is_safetensors else "pytorch_model.bin"
             index_name += ".index.json"
             index_file = self.dir_model / index_name
 
             if index_file.is_file():
-                self.tensor_names = set()
                 logger.info(f"gguf: loading model weight map from '{index_name}'")
                 with open(index_file, "r", encoding="utf-8") as f:
                     index: dict[str, Any] = json.load(f)
                     weight_map = index.get("weight_map")
                     if weight_map is None or not isinstance(weight_map, dict):
                         raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
-                    self.tensor_names.update(weight_map.keys())
+                    tensor_names_from_index.update(weight_map.keys())
             else:
-                self.tensor_names = tensor_names_from_parts
                 weight_map = {}
         else:
-            self.tensor_names = tensor_names_from_parts
             weight_map = {}
 
-        for part_name in self.part_names:
-            logger.info(f"gguf: loading model part '{part_name}'")
+        for part_name in part_names:
+            logger.info(f"gguf: indexing model part '{part_name}'")
             ctx: ContextManager[Any]
-            if self.is_safetensors:
+            if is_safetensors:
                 from safetensors import safe_open
                 ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
             else:
                 ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
 
             with ctx as model_part:
-                tensor_names_from_parts.update(model_part.keys())
+                assert model_part is not None
 
                 for name in model_part.keys():
-                    if self.is_safetensors:
+                    if is_safetensors:
                         if self.lazy:
                             data = model_part.get_slice(name)
-                            data = LazyTorchTensor.from_safetensors_slice(data)
+                            data_gen = lambda data=data: LazyTorchTensor.from_safetensors_slice(data)  # noqa: E731
                         else:
                             data = model_part.get_tensor(name)
+                            data_gen = lambda data=data: data  # noqa: E731
                     else:
                         data = model_part[name]
                         if self.lazy:
-                            data = LazyTorchTensor.from_eager(data)
-                    yield name, data
+                            data_gen = lambda data=data: LazyTorchTensor.from_eager(data)  # noqa: E731
+                        else:
+                            data_gen = lambda data=data: data  # noqa: E731
+                    tensors[name] = data_gen
 
         # verify tensor name presence and identify potentially missing files
-        if len(tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0:
-            missing = sorted(self.tensor_names.difference(tensor_names_from_parts))
-            extra = sorted(tensor_names_from_parts.difference(self.tensor_names))
-            missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
-            if len(extra) == 0 and len(missing_files) > 0:
-                raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
-                                 f"Missing tensors: {missing}")
+        if len(tensor_names_from_index) > 0:
+            tensor_names_from_parts = set(tensors.keys())
+            if len(tensor_names_from_parts.symmetric_difference(tensor_names_from_index)) > 0:
+                missing = sorted(tensor_names_from_index.difference(tensor_names_from_parts))
+                extra = sorted(tensor_names_from_parts.difference(tensor_names_from_index))
+                missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
+                if len(extra) == 0 and len(missing_files) > 0:
+                    raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
+                                     f"Missing tensors: {missing}")
+                else:
+                    raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
+                                     f"Missing tensors: {missing}\n"
+                                     f"Extra tensors: {extra}")
+
+        return tensors
+
+    def dequant_model(self):
+        tensors_to_remove: list[str] = []
+        new_tensors: dict[str, Callable[[], Tensor]] = {}
+
+        if (quant_config := self.hparams.get("quantization_config")) and isinstance(quant_config, dict):
+            quant_method = quant_config.get("quant_method")
+
+            def dequant_bitnet(weight: Tensor, scale: Tensor) -> Tensor:
+                weight = weight.view(torch.uint8)
+                orig_shape = weight.shape
+
+                shift = torch.tensor([0, 2, 4, 6], dtype=torch.uint8).reshape((4, *(1 for _ in range(len(orig_shape)))))
+                data = weight.unsqueeze(0).expand((4, *orig_shape)) >> shift
+                data = data & 3
+                data = (data.float() - 1).reshape((orig_shape[0] * 4, *orig_shape[1:]))
+
+                # The scale is inverted
+                return data / scale.float()
+
+            def dequant_simple(weight: Tensor, scale: Tensor) -> Tensor:
+                scale = scale.float()
+
+                if (weight_block_size := quant_config.get("weight_block_size")):
+                    # TODO: make sure it's a list of integers
+                    for i, size in enumerate(weight_block_size):
+                        scale = scale.repeat_interleave(size, i)
+                    # unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
+                    scale = scale[tuple(slice(0, size) for size in weight.shape)]
+
+                return weight.float() * scale
+
+            # ref: https://github.com/ModelCloud/GPTQModel/blob/037c5c0f6c9e33c500d975b038d02e7ca437546d/gptqmodel/nn_modules/qlinear/__init__.py#L437-L476
+            def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor) -> Tensor:
+                bits = quant_config["bits"]
+                assert bits in (2, 3, 4, 8)
+                assert qweight.dtype == qzeros.dtype
+                maxq = (2 ** bits) - 1
+                weight = None
+                zeros = None
+                pack_dtype_bits = qweight.dtype.itemsize * 8
+
+                if bits in [2, 4, 8]:
+                    pack_factor = pack_dtype_bits // bits
+                    wf = torch.tensor(list(range(0, pack_dtype_bits, bits)), dtype=torch.int32).unsqueeze(0)
+                    if self.lazy:
+                        wf = LazyTorchTensor.from_eager(wf)
+
+                    zeros = torch.bitwise_right_shift(
+                        qzeros.unsqueeze(2).expand(-1, -1, pack_factor),
+                        wf.unsqueeze(0)
+                    ).to(torch.int16 if bits == 8 else torch.int8)
+                    zeros = torch.bitwise_and(zeros, maxq).reshape(scales.shape)
+
+                    weight = torch.bitwise_and(
+                        torch.bitwise_right_shift(
+                            qweight.unsqueeze(1).expand(-1, pack_factor, -1),
+                            wf.unsqueeze(-1)
+                        ).to(torch.int16 if bits == 8 else torch.int8),
+                        maxq
+                    )
+                elif bits == 3:
+                    raise NotImplementedError("3-bit gptq dequantization is not yet implemented")
+
+                assert weight is not None
+                assert zeros is not None
+
+                weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
+
+                # gptq_v2 doesn't need to offset zeros
+                if quant_config.get("checkpoint_format", "gptq") == "gptq":
+                    zeros += 1
+
+                return (scales[g_idx].float() * (weight - zeros[g_idx]).float()).T
+
+            if quant_method == "bitnet":
+                for name in self.model_tensors.keys():
+                    if name.endswith(".weight_scale"):
+                        weight_name = name.removesuffix("_scale")
+                        w = self.model_tensors[weight_name]
+                        s = self.model_tensors[name]
+                        self.model_tensors[weight_name] = lambda w=w, s=s: dequant_bitnet(w(), s())
+                        tensors_to_remove.append(name)
+            elif quant_method == "fp8":
+                for name in self.model_tensors.keys():
+                    if name.endswith(".weight_scale_inv"):
+                        weight_name = name.removesuffix("_scale_inv")
+                        w = self.model_tensors[weight_name]
+                        s = self.model_tensors[name]
+                        self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s())
+                        tensors_to_remove.append(name)
+            elif quant_method == "gptq":
+                for name in self.model_tensors.keys():
+                    if name.endswith(".qweight"):
+                        base_name = name.removesuffix(".qweight")
+                        g_idx = self.model_tensors[base_name + ".g_idx"]
+                        qweight = self.model_tensors[base_name + ".qweight"]
+                        qzeros = self.model_tensors[base_name + ".qzeros"]
+                        scales = self.model_tensors[base_name + ".scales"]
+                        new_tensors[base_name + ".weight"] = (
+                            lambda g=g_idx, z=qzeros, w=qweight, s=scales: dequant_gptq(
+                                g(), w(), z(), s()
+                            )
+                        )
+                        tensors_to_remove += [
+                            base_name + n
+                            for n in (
+                                ".g_idx",
+                                ".qzeros",
+                                ".qweight",
+                                ".scales",
+                            )
+                        ]
             else:
-                raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
-                                 f"Missing tensors: {missing}\n"
-                                 f"Extra tensors: {extra}")
+                raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
+
+        for name in tensors_to_remove:
+            if name in self.model_tensors:
+                del self.model_tensors[name]
+
+        for name, value in new_tensors.items():
+            self.model_tensors[name] = value
+
+    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
+        for name, gen in self.model_tensors.items():
+            yield name, gen()
 
     def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
         if key not in gguf.MODEL_TENSORS[self.model_arch]:
@@ -4381,27 +4512,6 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
         self.gguf_writer.add_rope_scaling_factor(1.0)
 
-    _has_tok_embd = False
-
-    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        del bid  # unused
-
-        output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT)
-        tok_embd_name = self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD)
-
-        new_name = self.map_tensor_name(name)
-
-        # assuming token_embd.weight is seen before output.weight
-        if not self._has_tok_embd and new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT):
-            # even though the tensor file(s) does not contain the word embeddings they are still in the weight map
-            if self.tensor_names and "transformer.wte.weight" in self.tensor_names:
-                logger.debug(f"{tok_embd_name} not found before {output_name}, assuming they are tied")
-                self.tensor_names.remove("transformer.wte.weight")
-        elif new_name == tok_embd_name:
-            self._has_tok_embd = True
-
-        return [(new_name, data_torch)]
-
 
 @ModelBase.register("InternLM2ForCausalLM")
 class InternLM2Model(TextModel):
diff --git a/examples/model-conversion/scripts/causal/run-org-model.py b/examples/model-conversion/scripts/causal/run-org-model.py
index 9444c713d03ab..7fb55e9af1f52 100755
--- a/examples/model-conversion/scripts/causal/run-org-model.py
+++ b/examples/model-conversion/scripts/causal/run-org-model.py
@@ -138,7 +138,7 @@ def fn(_m, input, output):
         "Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
     )
 
-config = AutoConfig.from_pretrained(model_path)
+config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
 
 print("Model type: ", config.model_type)
 print("Vocab size: ", config.vocab_size)
@@ -148,8 +148,8 @@ def fn(_m, input, output):
 print("EOS token id: ", config.eos_token_id)
 
 print("Loading model and tokenizer using AutoTokenizer:", model_path)
-tokenizer = AutoTokenizer.from_pretrained(model_path)
-config = AutoConfig.from_pretrained(model_path)
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
 
 if unreleased_model_name:
     model_name_lower = unreleased_model_name.lower()
@@ -171,7 +171,7 @@ def fn(_m, input, output):
         exit(1)
 else:
     model = AutoModelForCausalLM.from_pretrained(
-        model_path, device_map="auto", offload_folder="offload"
offload_folder="offload" + model_path, device_map="auto", offload_folder="offload", trust_remote_code=True ) for name, module in model.named_modules(): diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index 607ded8558b45..6e7b90d42783f 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -1,5 +1,81 @@ #include "argsort.cuh" +#ifdef GGML_CUDA_USE_CUB +# include +using namespace cub; +#endif // GGML_CUDA_USE_CUB + +static __global__ void init_indices(int * indices, const int ncols, const int nrows) { + const int col = blockIdx.x * blockDim.x + threadIdx.x; + const int row = blockIdx.y; + + if (col < ncols && row < nrows) { + indices[row * ncols + col] = col; + } +} + +static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx <= nrows) { + offsets[idx] = idx * ncols; + } +} + +#ifdef GGML_CUDA_USE_CUB +static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, + const float * x, + int * dst, + const int ncols, + const int nrows, + ggml_sort_order order, + cudaStream_t stream) { + ggml_cuda_pool_alloc temp_indices_alloc(pool, ncols * nrows); + ggml_cuda_pool_alloc temp_keys_alloc(pool, ncols * nrows); + ggml_cuda_pool_alloc offsets_alloc(pool, nrows + 1); + + int * temp_indices = temp_indices_alloc.get(); + float * temp_keys = temp_keys_alloc.get(); + int * d_offsets = offsets_alloc.get(); + + static const int block_size = 256; + const dim3 grid_size((ncols + block_size - 1) / block_size, nrows); + init_indices<<>>(temp_indices, ncols, nrows); + + const dim3 offset_grid((nrows + block_size - 1) / block_size); + init_offsets<<>>(d_offsets, ncols, nrows); + + cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream); + + size_t temp_storage_bytes = 0; + + if (order == GGML_SORT_ORDER_ASC) { + DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols * nrows, nrows, // num items, num segments + d_offsets, d_offsets + 1, 0, sizeof(float) * 8, // all bits + stream); + } else { + DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, + dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, + sizeof(float) * 8, stream); + } + + ggml_cuda_pool_alloc temp_storage_alloc(pool, temp_storage_bytes); + void * d_temp_storage = temp_storage_alloc.get(); + + if (order == GGML_SORT_ORDER_ASC) { + DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, + ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8, + stream); + } else { + DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, + temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, + 0, sizeof(float) * 8, stream); + } +} +#endif // GGML_CUDA_USE_CUB + +// Bitonic sort implementation template static inline __device__ void ggml_cuda_swap(T & a, T & b) { T tmp = a; @@ -65,7 +141,12 @@ static int next_power_of_2(int x) { return n; } -static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) { +static void argsort_f32_i32_cuda_bitonic(const float * x, + int * dst, + const int ncols, + const int nrows, + ggml_sort_order order, + cudaStream_t stream) { // bitonic sort requires ncols to be power of 2 const int 
ncols_pad = next_power_of_2(ncols); @@ -77,9 +158,11 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb); if (order == GGML_SORT_ORDER_ASC) { - k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad); + k_argsort_f32_i32 + <<>>(x, dst, ncols, ncols_pad); } else if (order == GGML_SORT_ORDER_DESC) { - k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad); + k_argsort_f32_i32 + <<>>(x, dst, ncols, ncols_pad); } else { GGML_ABORT("fatal error"); } @@ -100,5 +183,18 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; - argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream); +#ifdef GGML_CUDA_USE_CUB + const int ncols_pad = next_power_of_2(ncols); + const size_t shared_mem = ncols_pad * sizeof(int); + const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb; + + if (shared_mem > max_shared_mem || ncols > 1024) { + ggml_cuda_pool & pool = ctx.pool(); + argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream); + } else { + argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream); + } +#else + argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream); +#endif } diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index 60240102741f3..0e6d777b1e64a 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -272,7 +272,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]); const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]); - if (block_nums.z > 65535) { + if (block_nums.z > 65535 || block_nums.y > 65535) { int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size; const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2)); const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1)); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index f5a6a751acfd5..bc396b521af07 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3642,8 +3642,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_SUM: return ggml_is_contiguous_rows(op->src[0]); case GGML_OP_ARGSORT: - // TODO: Support arbitrary column width +#ifndef GGML_CUDA_USE_CUB return op->src[0]->ne[0] <= 1024; +#else + return true; +#endif case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: case GGML_OP_GROUP_NORM: diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 991c62597962d..9eb2b66879c0b 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6407,6 +6407,7 @@ static std::vector> make_test_cases_eval() { add_test_bin_bcast(type, {1, 1, 640, 1}, {32, 32, 1, 1}); add_test_bin_bcast(type, {5120, 1, 1, 1}, {1, 256, 1, 1}); add_test_bin_bcast(type, {640, 1, 1, 1}, {1, 1, 1, 1}); + add_test_bin_bcast(type, {64, 262144, 1, 1}, {1, 1, 1, 1}); //add_test_bin_bcast(type, {3, 3, 2560, 1280}, {1, 1, 1, 1}); //add_test_bin_bcast(type, {3, 3, 2560, 1280}, {2, 1, 1, 1}); } diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index 7b56d87e430ea..026b53b28632f 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git 
a/tools/server/webui/src/routes/+page.svelte b/tools/server/webui/src/routes/+page.svelte index 2cd2d5c37313a..cd18dabccb9de 100644 --- a/tools/server/webui/src/routes/+page.svelte +++ b/tools/server/webui/src/routes/+page.svelte @@ -2,6 +2,9 @@ import { ChatScreen } from '$lib/components/app'; import { chatStore, isInitialized } from '$lib/stores/chat.svelte'; import { onMount } from 'svelte'; + import { page } from '$app/state'; + + let qParam = $derived(page.url.searchParams.get('q')); onMount(async () => { if (!isInitialized) { @@ -9,6 +12,11 @@ } chatStore.clearActiveConversation(); + + if (qParam !== null) { + await chatStore.createConversation(); + await chatStore.sendMessage(qParam); + } });