@@ -187,17 +187,42 @@ function (ex::AbstractComposableExpression)(
187187 return x .* nan
188188 end
189189end
190+ # Method for all-Number arguments (scalars)
191+ function (ex:: AbstractComposableExpression )(x:: Number , _xs:: Vararg{Number,N} ) where {N}
192+ xs = (x, _xs... )
193+
194+ vectors = ntuple (i -> ValidVector ([float (xs[i])], true ), length (xs))
195+ return only (_get_value (ex (vectors... )))
196+ end
197+
190198function (ex:: AbstractComposableExpression )(
191- x:: ValidVector , _xs:: Vararg{ValidVector,N}
199+ x:: Union{ ValidVector,Number} , _xs:: Vararg{Union{ ValidVector,Number} ,N}
192200) where {N}
193201 xs = (x, _xs... )
194- valid = all (_is_valid, xs)
195- if ! valid
196- return ValidVector (_get_value (first (xs)), false )
197- else
198- X = Matrix (stack (map (_get_value, xs))' )
202+ sample_vector =
203+ let first_valid_vector_idx = findfirst (arg -> arg isa ValidVector, xs):: Int
204+ xs[first_valid_vector_idx]:: ValidVector
205+ end
206+
207+ # Convert Numbers to ValidVectors based on first ValidVector's size
208+ valid_args = ntuple (length (xs)) do i
209+ arg = xs[i]
210+ if arg isa ValidVector
211+ arg
212+ else
213+ # Convert Number to ValidVector with repeated values
214+ filled_array = similar (sample_vector. x)
215+ fill! (filled_array, arg)
216+ ValidVector (filled_array, true )
217+ end
218+ end
219+
220+ if all (_is_valid, valid_args)
221+ X = stack (map (_get_value, valid_args); dims= 1 )
199222 eval_options = get_eval_options (ex)
200223 return ValidVector (eval_tree_array (ex, X; eval_options))
224+ else
225+ return ValidVector (_get_value (first (valid_args)), false )
201226 end
202227end
203228function (ex:: AbstractComposableExpression{T} )() where {T}
@@ -252,6 +277,55 @@ _is_valid(x) = true
252277_get_value (x:: ValidVector ) = x. x
253278_get_value (x) = x
254279
280+ struct ValidVectorMixError <: Exception end
281+ struct ValidVectorAccessError <: Exception end
282+
283+ function Base. showerror (io:: IO , :: ValidVectorMixError )
284+ return print (
285+ io,
286+ """
287+ ValidVectorMixError: Cannot mix ValidVector with regular Vector.
288+
289+ ValidVector handles validity checks, auto-vectorization, and batching in template expressions.
290+ The .valid field tracks whether any upstream computation failed (false = failed, true = valid).
291+
292+ Wrap your vectors in ValidVector:
293+
294+ ```julia
295+ valid_ar1 = ValidVector(ar1, all(isfinite, ar1))
296+ valid_ar1 + valid_ar2
297+ ```
298+
299+ Alternatively, you can access the vector from a ValidVector with `my_validvector.x`,
300+ but you must be sure to propagate the `.valid` field. For example:
301+
302+ ```julia
303+ out = ar1 .+ valid_ar2.x
304+ ValidVector(out, all(isfinite, out) && valid_ar2.valid)
305+ ```
306+
307+ """ ,
308+ )
309+ end
310+
311+ function Base. showerror (io:: IO , :: ValidVectorAccessError )
312+ return print (
313+ io,
314+ """
315+ ValidVectorAccessError: ValidVector doesn't support direct array operations.
316+
317+ Use .x for data and .valid for validity:
318+
319+ ```julia
320+ valid_ar.x[1] # indexing
321+ length(valid_ar.x) # length
322+ valid_ar.valid # check validity (false = any upstream computation failed)
323+ ```
324+
325+ ValidVector handles validity/batching automatically in template expressions.""" ,
326+ )
327+ end
328+
255329# ! format: off
256330# First, binary operators:
257331for op in (
@@ -264,6 +338,9 @@ for op in (
264338 Base.$ (op)(x:: ValidVector , y:: ValidVector ) = apply_operator (Base.$ (op), x, y)
265339 Base.$ (op)(x:: ValidVector , y:: Number ) = apply_operator (Base.$ (op), x, y)
266340 Base.$ (op)(x:: Number , y:: ValidVector ) = apply_operator (Base.$ (op), x, y)
341+
342+ Base.$ (op)(:: ValidVector , :: AbstractVector ) = throw (ValidVectorMixError ())
343+ Base.$ (op)(:: AbstractVector , :: ValidVector ) = throw (ValidVectorMixError ())
267344 end
268345end
269346function Base. literal_pow (:: typeof (^ ), x:: ValidVector , :: Val{p} ) where {p}
@@ -286,6 +363,12 @@ for op in (
286363end
287364# ! format: on
288365
366+ Base. length (:: ValidVector ) = throw (ValidVectorAccessError ())
367+ Base. push! (:: ValidVector , :: Any ) = throw (ValidVectorAccessError ())
368+ for op in (:getindex , :size , :append! , :setindex! )
369+ @eval Base.$ (op)(:: ValidVector , :: Any... ) = throw (ValidVectorAccessError ())
370+ end
371+
289372# TODO : Support for 3-ary operators
290373
291374end
0 commit comments